Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
24d59c79
Unverified
Commit
24d59c79
authored
Feb 26, 2024
by
fxmarty
Committed by
GitHub
Feb 26, 2024
Browse files
Use `torch.bool` instead of `torch.int64` for non-persistant causal mask buffer (#29241)
use torch.bool instead of torch.int64
parent
7c4995f9
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
13 additions
and
5 deletions
+13
-5
src/transformers/models/gemma/modeling_gemma.py
src/transformers/models/gemma/modeling_gemma.py
+5
-2
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+8
-3
No files found.
src/transformers/models/gemma/modeling_gemma.py
View file @
24d59c79
...
...
@@ -810,8 +810,11 @@ class GemmaModel(GemmaPreTrainedModel):
self
.
norm
=
GemmaRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
gradient_checkpointing
=
False
# register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
causal_mask
=
torch
.
full
((
config
.
max_position_embeddings
,
config
.
max_position_embeddings
),
fill_value
=
1
)
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
causal_mask
=
torch
.
full
(
(
config
.
max_position_embeddings
,
config
.
max_position_embeddings
),
fill_value
=
True
,
dtype
=
torch
.
bool
)
self
.
register_buffer
(
"causal_mask"
,
torch
.
triu
(
causal_mask
,
diagonal
=
1
),
persistent
=
False
)
# Initialize weights and apply final processing
self
.
post_init
()
...
...
src/transformers/models/llama/modeling_llama.py
View file @
24d59c79
...
...
@@ -811,7 +811,9 @@ class LlamaPreTrainedModel(PreTrainedModel):
)
if
max_cache_len
>
self
.
model
.
causal_mask
.
shape
[
-
1
]
or
self
.
device
!=
self
.
model
.
causal_mask
.
device
:
causal_mask
=
torch
.
full
((
max_cache_len
,
max_cache_len
),
fill_value
=
1
,
device
=
self
.
device
)
causal_mask
=
torch
.
full
(
(
max_cache_len
,
max_cache_len
),
fill_value
=
True
,
device
=
self
.
device
,
dtype
=
torch
.
bool
)
self
.
register_buffer
(
"causal_mask"
,
torch
.
triu
(
causal_mask
,
diagonal
=
1
),
persistent
=
False
)
for
layer
in
self
.
model
.
layers
:
...
...
@@ -919,8 +921,11 @@ class LlamaModel(LlamaPreTrainedModel):
self
.
norm
=
LlamaRMSNorm
(
config
.
hidden_size
,
eps
=
config
.
rms_norm_eps
)
self
.
gradient_checkpointing
=
False
# register a causal mask to separate causal and padding mask creation. Merging happends in the attention class
causal_mask
=
torch
.
full
((
config
.
max_position_embeddings
,
config
.
max_position_embeddings
),
fill_value
=
1
)
# Register a causal mask to separate causal and padding mask creation. Merging happens in the attention class.
# NOTE: This is not friendly with TorchScript, ONNX, ExportedProgram serialization for very large `max_position_embeddings`.
causal_mask
=
torch
.
full
(
(
config
.
max_position_embeddings
,
config
.
max_position_embeddings
),
fill_value
=
True
,
dtype
=
torch
.
bool
)
self
.
register_buffer
(
"causal_mask"
,
torch
.
triu
(
causal_mask
,
diagonal
=
1
),
persistent
=
False
)
# Initialize weights and apply final processing
self
.
post_init
()
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment