Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
1e21c4fb
Unverified
Commit
1e21c4fb
authored
Mar 13, 2024
by
Joao Gante
Committed by
GitHub
Mar 13, 2024
Browse files
Llama: allow custom 4d masks (#29618)
parent
88a4f68f
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
35 additions
and
41 deletions
+35
-41
src/transformers/models/gemma/modeling_gemma.py
src/transformers/models/gemma/modeling_gemma.py
+9
-4
src/transformers/models/llama/modeling_llama.py
src/transformers/models/llama/modeling_llama.py
+9
-4
tests/test_modeling_utils.py
tests/test_modeling_utils.py
+17
-33
No files found.
src/transformers/models/gemma/modeling_gemma.py
View file @
1e21c4fb
...
...
@@ -975,11 +975,16 @@ class GemmaModel(GemmaPreTrainedModel):
causal_mask
=
self
.
causal_mask
[
None
,
None
,
:,
:].
to
(
dtype
=
dtype
,
device
=
device
)
*
min_dtype
causal_mask
=
causal_mask
.
expand
(
batch_size
,
1
,
-
1
,
-
1
)
if
attention_mask
is
not
None
and
attention_mask
.
dim
()
==
2
:
if
attention_mask
is
not
None
:
causal_mask
=
causal_mask
.
clone
()
# copy to contiguous memory for in-place edit
mask_length
=
attention_mask
.
shape
[
-
1
]
padding_mask
=
causal_mask
[...,
:
mask_length
].
eq
(
0.0
)
*
attention_mask
[:,
None
,
None
,
:].
eq
(
0.0
)
causal_mask
[...,
:
mask_length
]
=
causal_mask
[...,
:
mask_length
].
masked_fill
(
padding_mask
,
min_dtype
)
if
attention_mask
.
dim
()
==
2
:
mask_length
=
attention_mask
.
shape
[
-
1
]
padding_mask
=
causal_mask
[...,
:
mask_length
].
eq
(
0.0
)
*
attention_mask
[:,
None
,
None
,
:].
eq
(
0.0
)
causal_mask
[...,
:
mask_length
]
=
causal_mask
[...,
:
mask_length
].
masked_fill
(
padding_mask
,
min_dtype
)
elif
attention_mask
.
dim
()
==
4
:
mask_shape
=
attention_mask
.
shape
mask_slice
=
(
attention_mask
.
eq
(
0.0
)).
to
(
dtype
=
dtype
)
*
min_dtype
causal_mask
[:
mask_shape
[
0
],
:
mask_shape
[
1
],
:
mask_shape
[
2
],
:
mask_shape
[
3
]]
=
mask_slice
if
(
self
.
config
.
_attn_implementation
==
"sdpa"
...
...
src/transformers/models/llama/modeling_llama.py
View file @
1e21c4fb
...
...
@@ -1083,11 +1083,16 @@ class LlamaModel(LlamaPreTrainedModel):
min_dtype
=
torch
.
finfo
(
dtype
).
min
causal_mask
=
self
.
causal_mask
[
None
,
None
,
:,
:].
to
(
dtype
=
dtype
,
device
=
device
)
*
min_dtype
causal_mask
=
causal_mask
.
expand
(
batch_size
,
1
,
-
1
,
-
1
)
if
attention_mask
is
not
None
and
attention_mask
.
dim
()
==
2
:
if
attention_mask
is
not
None
:
causal_mask
=
causal_mask
.
clone
()
# copy to contiguous memory for in-place edit
mask_length
=
attention_mask
.
shape
[
-
1
]
padding_mask
=
causal_mask
[...,
:
mask_length
].
eq
(
0.0
)
*
attention_mask
[:,
None
,
None
,
:].
eq
(
0.0
)
causal_mask
[...,
:
mask_length
]
=
causal_mask
[...,
:
mask_length
].
masked_fill
(
padding_mask
,
min_dtype
)
if
attention_mask
.
dim
()
==
2
:
mask_length
=
attention_mask
.
shape
[
-
1
]
padding_mask
=
causal_mask
[...,
:
mask_length
].
eq
(
0.0
)
*
attention_mask
[:,
None
,
None
,
:].
eq
(
0.0
)
causal_mask
[...,
:
mask_length
]
=
causal_mask
[...,
:
mask_length
].
masked_fill
(
padding_mask
,
min_dtype
)
elif
attention_mask
.
dim
()
==
4
:
mask_shape
=
attention_mask
.
shape
mask_slice
=
(
attention_mask
.
eq
(
0.0
)).
to
(
dtype
=
dtype
)
*
min_dtype
causal_mask
[:
mask_shape
[
0
],
:
mask_shape
[
1
],
:
mask_shape
[
2
],
:
mask_shape
[
3
]]
=
mask_slice
if
(
self
.
config
.
_attn_implementation
==
"sdpa"
...
...
tests/test_modeling_utils.py
View file @
1e21c4fb
...
...
@@ -1992,6 +1992,8 @@ class Mask4DTestBase(unittest.TestCase):
# [ 1, 278, 6635, 750],
# [ 1, 278, 6635, 338]], device='cuda:0')
position_ids_0
=
torch
.
tensor
([[
0
,
1
,
2
,
3
]]
*
3
,
device
=
torch_device
,
dtype
=
torch
.
int64
)
# Combining common prefix with the unique ending tokens:
input_1
=
torch
.
cat
([
input_0
[
0
][:
-
1
],
input_0
[:,
-
1
]]).
unsqueeze
(
0
)
# tensor([[ 1, 278, 6635, 3290, 750, 338]], device='cuda:0')
...
...
@@ -2017,81 +2019,63 @@ class Mask4DTestBase(unittest.TestCase):
# Creating a position_ids tensor. note the repeating figures in the end.
position_ids_1
=
torch
.
tensor
([[
0
,
1
,
2
,
3
,
3
,
3
]],
device
=
torch_device
,
dtype
=
torch
.
int64
)
return
input_0
,
input_1
,
mask_1
,
position_ids_1
return
input_0
,
position_ids_0
,
input_1
,
mask_1
,
position_ids_1
@
slow
@
require_torch_gpu
class
Mask4DTestFP32
(
Mask4DTestBase
):
def
setUp
(
self
):
model_name
=
"JackFram/llama-68m"
# small Llama-like model from FlexFlow
model_dtype
=
torch
.
float32
self
.
model_dtype
=
torch
.
float32
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
,
torch_dtype
=
model_dtype
).
to
(
torch_device
)
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
,
torch_dtype
=
self
.
model_dtype
).
to
(
torch_device
)
def
test_attention
(
self
):
"""comparing outputs of attention layer"""
input_0
,
input_1
,
mask_1
,
position_ids_1
=
self
.
get_test_data
()
input_0
,
position_ids_0
,
input_1
,
mask_1
,
position_ids_1
=
self
.
get_test_data
()
causal_mask_1
=
(
1
-
mask_1
).
to
(
self
.
model_dtype
)
*
torch
.
finfo
(
self
.
model_dtype
).
min
hid_0
=
self
.
model
.
model
.
embed_tokens
(
input_0
)
outs_0
=
self
.
model
.
model
.
layers
[
0
].
self_attn
.
forward
(
hid_0
)[
0
]
outs_0
=
self
.
model
.
model
.
layers
[
0
].
self_attn
.
forward
(
hid_0
,
position_ids
=
position_ids_0
)[
0
]
# outs_0.shape == torch.Size([3, 4, 768])
hid_1
=
self
.
model
.
model
.
embed_tokens
(
input_1
)
outs_1
=
self
.
model
.
model
.
layers
[
0
].
self_attn
.
forward
(
hid_1
,
attention_mask
=
mask_1
.
bool
()
,
position_ids
=
position_ids_1
hid_1
,
attention_mask
=
causal_mask_1
,
position_ids
=
position_ids_1
)[
0
]
# outs_1.shape == torch.Size([1, 6, 768])
outs_0_last_tokens
=
outs_0
[:,
-
1
,
:]
# last tokens in each batch line
outs_1_last_tokens
=
outs_1
[
0
,
-
3
:,
:]
# last three tokens
assert
torch
.
allclose
(
outs_0_last_tokens
,
outs_1_last_tokens
)
def
test_inner_model
(
self
):
"""comparing hidden outputs of whole inner model"""
input_0
,
input_1
,
mask_1
,
position_ids_1
=
self
.
get_test_data
()
logits_0
=
self
.
model
.
forward
(
input_0
).
logits
logits_1
=
self
.
model
.
forward
(
input_1
,
attention_mask
=
mask_1
.
bool
(),
position_ids
=
position_ids_1
).
logits
logits_0_last_tokens
=
logits_0
[:,
-
1
,
:]
# last tokens in each batch line
logits_1_last_tokens
=
logits_1
[
0
,
-
3
:,
:]
# last three tokens
torch
.
testing
.
assert_close
(
logits_0_last_tokens
,
logits_1_last_tokens
,
)
torch
.
testing
.
assert_close
(
outs_0_last_tokens
,
outs_1_last_tokens
)
def
test_causal_model_logits
(
self
):
"""comparing logits outputs of whole inner model"""
input_0
,
input_1
,
mask_1
,
position_ids_1
=
self
.
get_test_data
()
input_0
,
position_ids_0
,
input_1
,
mask_1
,
position_ids_1
=
self
.
get_test_data
()
logits_0
=
self
.
model
.
forward
(
input_0
).
logits
logits_0
=
self
.
model
.
forward
(
input_0
,
position_ids
=
position_ids_0
).
logits
logits_1
=
self
.
model
.
forward
(
input_1
,
attention_mask
=
mask_1
.
bool
(),
position_ids
=
position_ids_1
).
logits
logits_0_last_tokens
=
logits_0
[:,
-
1
,
:]
# last tokens in each batch line
logits_1_last_tokens
=
logits_1
[
0
,
-
3
:,
:]
# last three tokens
torch
.
testing
.
assert_close
(
logits_0_last_tokens
,
logits_1_last_tokens
,
)
torch
.
testing
.
assert_close
(
logits_0_last_tokens
,
logits_1_last_tokens
)
@
slow
@
require_torch_gpu
class
Mask4DTestFP16
(
Mask4DTestBase
):
test_attention
=
Mask4DTestFP32
.
test_attention
def
setUp
(
self
):
model_name
=
"JackFram/llama-68m"
# small Llama-like model from FlexFlow
model_dtype
=
torch
.
float16
self
.
model_dtype
=
torch
.
float16
self
.
tokenizer
=
AutoTokenizer
.
from_pretrained
(
model_name
)
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
,
torch_dtype
=
model_dtype
).
to
(
torch_device
)
self
.
model
=
AutoModelForCausalLM
.
from_pretrained
(
model_name
,
torch_dtype
=
self
.
model_dtype
).
to
(
torch_device
)
def
test_causal_model_logits
(
self
):
"""comparing logits outputs of whole inner model"""
input_0
,
input_1
,
mask_1
,
position_ids_1
=
self
.
get_test_data
()
input_0
,
position_ids_0
,
input_1
,
mask_1
,
position_ids_1
=
self
.
get_test_data
()
logits_0
=
self
.
model
.
forward
(
input_0
).
logits
logits_0
=
self
.
model
.
forward
(
input_0
,
position_ids
=
position_ids_0
).
logits
logits_1
=
self
.
model
.
forward
(
input_1
,
attention_mask
=
mask_1
.
bool
(),
position_ids
=
position_ids_1
).
logits
logits_0_last_tokens
=
logits_0
[:,
-
1
,
:]
# last tokens in each batch line
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment