"app/git@developer.sourcefind.cn:orangecat/ollama.git" did not exist on "e64ef69e348b6f4396543c7a5e1c196d6a5287a0"
Unverified commit 24d59c79, authored by fxmarty, committed by GitHub

Use `torch.bool` instead of `torch.int64` for non-persistent causal mask buffer (#29241)

use torch.bool instead of torch.int64
parent 7c4995f9
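For scale, here is an illustrative sketch (not part of the diff) of what the dtype change saves on the registered buffer, assuming an example `max_position_embeddings` of 8192: `torch.full(..., fill_value=1)` with no explicit dtype allocates `torch.int64` (8 bytes per element), while the new `torch.bool` buffer uses 1 byte per element.

```python
import torch

# Illustrative sketch only (not part of the diff): memory footprint of the causal-mask
# buffer before and after the change, for an assumed max_position_embeddings of 8192.
n = 8192

# Before: an integer fill_value with no explicit dtype defaults to torch.int64 (8 bytes/element).
mask_int64 = torch.full((n, n), fill_value=1)

# After: an explicit torch.bool buffer (1 byte/element).
mask_bool = torch.full((n, n), fill_value=True, dtype=torch.bool)


def mib(t):
    return t.numel() * t.element_size() / 2**20


print(mask_int64.dtype, f"{mib(mask_int64):.0f} MiB")  # torch.int64 512 MiB
print(mask_bool.dtype, f"{mib(mask_bool):.0f} MiB")    # torch.bool 64 MiB
```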
@@ -810,8 +810,11 @@ class GemmaModel(GemmaPreTrainedModel):
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
-        # register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
-        causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1)
+        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
+        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
+        causal_mask = torch.full(
+            (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
+        )
         self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
         # Initialize weights and apply final processing
         self.post_init()
......
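Since the mask is registered with `persistent=False` in both models, it never lands in the `state_dict`. A minimal standalone sketch of that behaviour, using a toy module rather than the transformers code:

```python
import torch
from torch import nn

# Minimal sketch (toy module, not the transformers code): persistent=False keeps the
# causal mask out of the state_dict, so checkpoints do not grow with the mask size.
class Toy(nn.Module):
    def __init__(self, n=16):
        super().__init__()
        mask = torch.full((n, n), fill_value=True, dtype=torch.bool)
        self.register_buffer("causal_mask", torch.triu(mask, diagonal=1), persistent=False)


print("causal_mask" in Toy().state_dict())  # False: the buffer is rebuilt in __init__ instead
```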
@@ -811,7 +811,9 @@ class LlamaPreTrainedModel(PreTrainedModel):
             )
         if max_cache_len > self.model.causal_mask.shape[-1] or self.device != self.model.causal_mask.device:
-            causal_mask = torch.full((max_cache_len, max_cache_len), fill_value=1, device=self.device)
+            causal_mask = torch.full(
+                (max_cache_len, max_cache_len), fill_value=True, device=self.device, dtype=torch.bool
+            )
             self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
         for layer in self.model.layers:
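For reference, this is what `torch.triu(causal_mask, diagonal=1)` leaves in the buffer, shown for a small illustrative size: `True` marks the strictly-upper-triangular (future) positions, and the diagonal stays `False` so each token can attend to itself.

```python
import torch

# Illustrative sketch for a small size: the registered buffer is strictly upper-triangular,
# with True marking future positions that a query token must not attend to.
causal_mask = torch.full((4, 4), fill_value=True, dtype=torch.bool)
print(torch.triu(causal_mask, diagonal=1))
# tensor([[False,  True,  True,  True],
#         [False, False,  True,  True],
#         [False, False, False,  True],
#         [False, False, False, False]])
```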
@@ -919,8 +921,11 @@ class LlamaModel(LlamaPreTrainedModel):
         self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
         self.gradient_checkpointing = False
-        # register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
-        causal_mask = torch.full((config.max_position_embeddings, config.max_position_embeddings), fill_value=1)
+        # Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
+        # NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
+        causal_mask = torch.full(
+            (config.max_position_embeddings, config.max_position_embeddings), fill_value=True, dtype=torch.bool
+        )
         self.register_buffer("causal_mask", torch.triu(causal_mask, diagonal=1), persistent=False)
         # Initialize weights and apply final processing
         self.post_init()
......
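The comment in both hunks notes that merging with the padding mask happens in the attention classes, which are not part of this diff. Below is a hedged sketch of how a boolean causal mask like this one can be combined with a boolean padding mask and applied to attention scores; the names and shapes are illustrative, not the library's internal API.

```python
import torch

# Hedged sketch, not the attention code in this diff: combine a boolean causal mask with a
# boolean padding mask (True = masked) and apply the result to raw attention scores.
seq_len = 5
causal_mask = torch.triu(torch.full((seq_len, seq_len), True, dtype=torch.bool), diagonal=1)
padding_mask = torch.tensor([[False, False, False, True, True]])  # last two positions are padding

# Broadcast both to (batch, heads, q_len, kv_len) and OR them together.
merged = causal_mask[None, None, :, :] | padding_mask[:, None, None, :]

attn_scores = torch.randn(1, 1, seq_len, seq_len)
attn_scores = attn_scores.masked_fill(merged, torch.finfo(attn_scores.dtype).min)
attn_weights = attn_scores.softmax(dim=-1)
print(attn_weights[0, 0].round(decimals=2))
```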