chenpangpang / transformers · Commits

Commit 1e21c4fb (unverified)
Llama: allow custom 4d masks (#29618)

Authored Mar 13, 2024 by Joao Gante, committed by GitHub on Mar 13, 2024
Parent: 88a4f68f

Showing 3 changed files with 35 additions and 41 deletions (+35 −41)
src/transformers/models/gemma/modeling_gemma.py   +9  −4
src/transformers/models/llama/modeling_llama.py   +9  −4
tests/test_modeling_utils.py                      +17 −33
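
With this change, the causal-mask update in LlamaModel and GemmaModel accepts a ready-made 4D attention mask of shape (batch, 1, query_length, kv_length), where 0 marks a blocked position and any nonzero value an attendable one, in addition to the usual 2D padding mask. A minimal usage sketch, assuming a transformers build that includes this commit; the checkpoint choice and the trivially causal mask are only illustrative:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "JackFram/llama-68m"  # the small Llama-like checkpoint the tests use
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

inputs = tokenizer("custom 4d masks", return_tensors="pt")
seq_len = inputs.input_ids.shape[1]

# 4D mask of shape (batch, 1, query_len, kv_len): nonzero = attend, 0 = blocked.
# Here it is just the ordinary causal mask, so results match the default path.
mask_4d = torch.ones(seq_len, seq_len, dtype=torch.int64).tril().bool()[None, None, :, :]
position_ids = torch.arange(seq_len)[None, :]

out = model(inputs.input_ids, attention_mask=mask_4d, position_ids=position_ids)

Because the example mask equals the default causal mask, the output should match a plain model(inputs.input_ids) call; the point is only the accepted shape and dtype.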
src/transformers/models/gemma/modeling_gemma.py

@@ -975,11 +975,16 @@ class GemmaModel(GemmaPreTrainedModel):
         causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
         causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
-        if attention_mask is not None and attention_mask.dim() == 2:
+        if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+            if attention_mask.dim() == 2:
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
+                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+            elif attention_mask.dim() == 4:
+                mask_shape = attention_mask.shape
+                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
+                causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice

         if (
             self.config._attn_implementation == "sdpa"
             ...
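
The new elif attention_mask.dim() == 4 branch converts the user-supplied 0/1 (or boolean) mask into the additive float mask used internally (0.0 where attention is allowed, torch.finfo(dtype).min where it is blocked) and writes it over the matching slice of the expanded causal mask. A standalone sketch of that conversion, outside the model and with made-up shapes:

import torch

dtype = torch.float32
min_dtype = torch.finfo(dtype).min
batch_size, q_len, kv_len, max_len = 1, 6, 6, 16

# Approximation of the expanded causal mask the model builds:
# 0.0 where attention is allowed, min_dtype in the strictly upper triangle.
causal_mask = torch.full((max_len, max_len), min_dtype, dtype=dtype).triu(diagonal=1)
causal_mask = causal_mask[None, None, :, :].expand(batch_size, 1, -1, -1).clone()

# custom 4D mask: nonzero = attend, 0 = blocked (illustrative: everything allowed)
attention_mask = torch.ones(batch_size, 1, q_len, kv_len)

# mirror of the new branch: blocked positions become min_dtype, allowed ones 0.0
mask_shape = attention_mask.shape
mask_slice = attention_mask.eq(0.0).to(dtype=dtype) * min_dtype
causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice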
src/transformers/models/llama/modeling_llama.py

@@ -1083,11 +1083,16 @@ class LlamaModel(LlamaPreTrainedModel):
         min_dtype = torch.finfo(dtype).min
         causal_mask = self.causal_mask[None, None, :, :].to(dtype=dtype, device=device) * min_dtype
         causal_mask = causal_mask.expand(batch_size, 1, -1, -1)
-        if attention_mask is not None and attention_mask.dim() == 2:
+        if attention_mask is not None:
             causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
-            mask_length = attention_mask.shape[-1]
-            padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
-            causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+            if attention_mask.dim() == 2:
+                mask_length = attention_mask.shape[-1]
+                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
+                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
+            elif attention_mask.dim() == 4:
+                mask_shape = attention_mask.shape
+                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
+                causal_mask[: mask_shape[0], : mask_shape[1], : mask_shape[2], : mask_shape[3]] = mask_slice

         if (
             self.config._attn_implementation == "sdpa"
             ...
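
The Llama change is identical to the Gemma one. Note the additive-mask convention it produces: when code bypasses the model-level mask update and calls an attention layer directly, as the updated test_attention below does, the 0/1 mask first has to be converted with (1 - mask).to(dtype) * torch.finfo(dtype).min. A small sketch with a made-up 4D mask:

import torch

dtype = torch.float32

# 0/1 keep-mask for a packed sequence of length 4: 1 = attend, 0 = blocked
mask_1 = torch.tensor([[[[1, 0, 0, 0],
                         [1, 1, 0, 0],
                         [1, 1, 1, 0],
                         [1, 1, 1, 1]]]])

# additive form expected when calling an attention layer directly:
# 0.0 where allowed, the most negative representable value where blocked
causal_mask_1 = (1 - mask_1).to(dtype) * torch.finfo(dtype).min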
tests/test_modeling_utils.py

@@ -1992,6 +1992,8 @@ class Mask4DTestBase(unittest.TestCase):
         # [ 1, 278, 6635, 750],
         # [ 1, 278, 6635, 338]], device='cuda:0')
+        position_ids_0 = torch.tensor([[0, 1, 2, 3]] * 3, device=torch_device, dtype=torch.int64)
+
         # Combining common prefix with the unique ending tokens:
         input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
         # tensor([[ 1, 278, 6635, 3290, 750, 338]], device='cuda:0')

@@ -2017,81 +2019,63 @@ class Mask4DTestBase(unittest.TestCase):
         # Creating a position_ids tensor. note the repeating figures in the end.
         position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=torch_device, dtype=torch.int64)

-        return input_0, input_1, mask_1, position_ids_1
+        return input_0, position_ids_0, input_1, mask_1, position_ids_1


 @slow
 @require_torch_gpu
 class Mask4DTestFP32(Mask4DTestBase):
     def setUp(self):
         model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
-        model_dtype = torch.float32
+        self.model_dtype = torch.float32
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)

     def test_attention(self):
         """comparing outputs of attention layer"""
-        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+        causal_mask_1 = (1 - mask_1).to(self.model_dtype) * torch.finfo(self.model_dtype).min

         hid_0 = self.model.model.embed_tokens(input_0)
-        outs_0 = self.model.model.layers[0].self_attn.forward(hid_0)[0]
+        outs_0 = self.model.model.layers[0].self_attn.forward(hid_0, position_ids=position_ids_0)[0]
         # outs_0.shape == torch.Size([3, 4, 768])

         hid_1 = self.model.model.embed_tokens(input_1)
         outs_1 = self.model.model.layers[0].self_attn.forward(
-            hid_1, attention_mask=mask_1.bool(), position_ids=position_ids_1
+            hid_1, attention_mask=causal_mask_1, position_ids=position_ids_1
         )[0]
         # outs_1.shape == torch.Size([1, 6, 768])

         outs_0_last_tokens = outs_0[:, -1, :]  # last tokens in each batch line
         outs_1_last_tokens = outs_1[0, -3:, :]  # last three tokens
-        assert torch.allclose(outs_0_last_tokens, outs_1_last_tokens)
-
-    def test_inner_model(self):
-        """comparing hidden outputs of whole inner model"""
-        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
-
-        logits_0 = self.model.forward(input_0).logits
-        logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits
-
-        logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
-        logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
-        torch.testing.assert_close(
-            logits_0_last_tokens,
-            logits_1_last_tokens,
-        )
+        torch.testing.assert_close(outs_0_last_tokens, outs_1_last_tokens)

     def test_causal_model_logits(self):
         """comparing logits outputs of whole inner model"""
-        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()

-        logits_0 = self.model.forward(input_0).logits
+        logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
         logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits

         logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
         logits_1_last_tokens = logits_1[0, -3:, :]  # last three tokens
-        torch.testing.assert_close(
-            logits_0_last_tokens,
-            logits_1_last_tokens,
-        )
+        torch.testing.assert_close(logits_0_last_tokens, logits_1_last_tokens)


 @slow
 @require_torch_gpu
 class Mask4DTestFP16(Mask4DTestBase):
     test_attention = Mask4DTestFP32.test_attention

     def setUp(self):
         model_name = "JackFram/llama-68m"  # small Llama-like model from FlexFlow
-        model_dtype = torch.float16
+        self.model_dtype = torch.float16
         self.tokenizer = AutoTokenizer.from_pretrained(model_name)
-        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=model_dtype).to(torch_device)
+        self.model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=self.model_dtype).to(torch_device)

     def test_causal_model_logits(self):
         """comparing logits outputs of whole inner model"""
-        input_0, input_1, mask_1, position_ids_1 = self.get_test_data()
+        input_0, position_ids_0, input_1, mask_1, position_ids_1 = self.get_test_data()

-        logits_0 = self.model.forward(input_0).logits
+        logits_0 = self.model.forward(input_0, position_ids=position_ids_0).logits
         logits_1 = self.model.forward(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits

         logits_0_last_tokens = logits_0[:, -1, :]  # last tokens in each batch line
         ...
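
Putting the pieces together: the tests pack three sequences that share a 3-token prefix into one row and rely on a custom 4D mask plus repeated position ids to reproduce the batched forward pass. The body of get_test_data is collapsed in this diff, so the following reconstruction is only a sketch; the token ids come from the comments above, while the mask construction and the CPU device are assumptions:

import torch
from transformers import AutoModelForCausalLM

device = "cpu"  # the tests run on GPU via torch_device; CPU shown here for simplicity
model = AutoModelForCausalLM.from_pretrained("JackFram/llama-68m").to(device)

# three sequences sharing the prefix [1, 278, 6635]
input_0 = torch.tensor([[1, 278, 6635, 3290],
                        [1, 278, 6635, 750],
                        [1, 278, 6635, 338]], device=device)
position_ids_0 = torch.tensor([[0, 1, 2, 3]] * 3, device=device)

# packed into one row: shared prefix followed by the three alternative last tokens
input_1 = torch.cat([input_0[0][:-1], input_0[:, -1]]).unsqueeze(0)
position_ids_1 = torch.tensor([[0, 1, 2, 3, 3, 3]], device=device)

# 4D mask (1, 1, 6, 6): prefix tokens attend causally; each ending token sees the prefix and itself only
mask_1 = torch.zeros(1, 1, 6, 6, dtype=torch.int64, device=device)
mask_1[0, 0, :3, :3] = torch.tril(torch.ones(3, 3, dtype=torch.int64, device=device))
for i in range(3, 6):
    mask_1[0, 0, i, :3] = 1
    mask_1[0, 0, i, i] = 1

logits_0 = model(input_0, position_ids=position_ids_0).logits
logits_1 = model(input_1, attention_mask=mask_1.bool(), position_ids=position_ids_1).logits

# last token of each batched row should match the three packed ending positions
torch.testing.assert_close(logits_0[:, -1, :], logits_1[0, -3:, :])

Each packed ending token attends to the shared prefix and to itself only, so its logits should match the last-token logits of the corresponding batched row, which is what test_causal_model_logits asserts.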