gaoqiong / flash-attention / Commits

Commit a9a4b4e4, authored May 04, 2023 by Tri Dao
[LLaMa] Fix last norm layer to use RMSNorm instead of LayerNorm
parent ad113948

Showing 2 changed files with 13 additions and 9 deletions
flash_attn/models/gpt.py      (+7, -2)
tests/models/test_llama.py    (+6, -7)
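For context on the fix: LLaMA normalizes with RMSNorm throughout, so the final norm applied after the last transformer block must be RMSNorm as well; before this commit the final fused norm always took the LayerNorm path even when self.ln_f was an RMSNorm. As a reminder of why the two are not interchangeable, here is a minimal pure-PyTorch sketch (reference math only, not the fused kernels this commit touches): LayerNorm centers by the mean and applies a bias, while RMSNorm only rescales by the root-mean-square.

import torch

def layer_norm_ref(x, weight, bias, eps=1e-5):
    # LayerNorm: subtract the mean, scale by the standard deviation,
    # then apply an affine transform with weight and bias.
    mean = x.mean(dim=-1, keepdim=True)
    var = x.var(dim=-1, unbiased=False, keepdim=True)
    return (x - mean) / torch.sqrt(var + eps) * weight + bias

def rms_norm_ref(x, weight, eps=1e-6):
    # RMSNorm (what LLaMA uses): no mean subtraction and no bias,
    # just rescaling by the root-mean-square of the activations.
    rms = torch.sqrt(x.pow(2).mean(dim=-1, keepdim=True) + eps)
    return x / rms * weight

x = torch.randn(2, 8)
w, b = torch.ones(8), torch.zeros(8)
# The two agree only when the input happens to be zero-mean.
print((layer_norm_ref(x, w, b) - rms_norm_ref(x, w)).abs().max().item())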
flash_attn/models/gpt.py

@@ -377,13 +377,18 @@ class GPTModel(GPTPreTrainedModel):
         else:
             # Set prenorm=False here since we don't need the residual
             if not self.parallel_block:
-                hidden_states = dropout_add_layer_norm(
+                fused_add_norm_fn = (dropout_add_rms_norm if isinstance(self.ln_f, RMSNorm)
+                                     else dropout_add_layer_norm)
+                hidden_states = fused_add_norm_fn(
                     hidden_states, residual, self.ln_f.weight, self.ln_f.bias,
                     self.drop_f.p if self.training else 0.0, self.ln_f.eps, prenorm=False,
                     residual_in_fp32=self.residual_in_fp32
                 )
             else:
-                hidden_states, _ = dropout_add_layer_norm_parallel_residual(
+                fused_add_norm_fn = (dropout_add_rms_norm_parallel_residual
+                                     if isinstance(self.ln_f, RMSNorm)
+                                     else dropout_add_layer_norm_parallel_residual)
+                hidden_states, _ = fused_add_norm_fn(
                     hidden_states, hidden_states2, residual, self.ln_f.weight, self.ln_f.bias,
                     None, None, self.drop_f.p if self.training else 0.0, self.ln_f.eps,
                     prenorm=False, residual_in_fp32=self.residual_in_fp32
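The fix keeps the fused call's argument list and only swaps which kernel is dispatched, based on the class of self.ln_f. For readers unfamiliar with the fused ops, the sketch below is an unfused reference of what a dropout_add_*_norm(..., prenorm=False) call computes, under the assumption (inferred from the surrounding code rather than the kernel source) that it applies dropout to the block output, adds the running residual, then normalizes:

import torch
import torch.nn.functional as F

def dropout_add_norm_ref(x0, residual, weight, bias, dropout_p, eps,
                         use_rms_norm=False):
    # Unfused reference: dropout the block output, add the running residual,
    # then normalize. prenorm=False in the fused call means only the
    # normalized hidden states come back, not the updated residual.
    # (The fused kernel can also keep the residual accumulator in fp32
    # via residual_in_fp32; that detail is omitted here.)
    mixed = F.dropout(x0, p=dropout_p) + (residual if residual is not None else 0)
    if use_rms_norm:
        # RMSNorm: no mean subtraction and no bias -- the LLaMA path.
        rms = torch.sqrt(mixed.pow(2).mean(dim=-1, keepdim=True) + eps)
        return mixed / rms * weight
    return F.layer_norm(mixed, weight.shape, weight, bias, eps)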
tests/models/test_llama.py

@@ -176,13 +176,13 @@ def test_llama_parallel(model_name, world_size):
     print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
     print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
     print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
-    assert (out - out_ref).abs().max().item() < 3 * (out_hf - out_ref).abs().max().item()
+    assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
 
     print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
     print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
     print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
     print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
-    assert (logits - logits_ref).abs().max().item() < 3 * (logits_hf - logits_ref).abs().max().item()
+    assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
 
 
 @pytest.mark.parametrize('model_name', ["7B"])
@@ -267,11 +267,10 @@ def test_llama_generation(model_name):
     del model
 
     hf_error = (logits_hf - logits_ref).abs().max().item()
-    # For some reason logits_parallel is off by quite a bit more than 2x
-    assert (logits_parallel - logits_ref).abs().max().item() < 8 * hf_error
+    assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
     print(f'HF fp16 logits max diff: {hf_error}')
-    print(f'Logits max diff: {(logits - logits_parallel).abs().max().item()}')
-    assert (logits - logits_parallel).abs().max().item() < 2 * hf_error
-    print(f'Logits CG max diff: {(logits_cg - logits_parallel).abs().max().item()}')
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+    assert (logits - logits_ref).abs().max().item() < 2 * hf_error
+    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
     assert torch.equal(logits_cg, logits)
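The test changes tighten the accuracy contract rather than change behavior: after the final-norm fix, the fused model's logits must land within 2x (previously 3x and 8x) of the error that the Hugging Face fp16 implementation itself incurs against the fp32 reference, and the generation test now measures against logits_ref instead of logits_parallel. A minimal standalone version of that relative-tolerance idiom (hypothetical tensors, not the actual test fixtures):

import torch

def assert_close_relative(out, out_ref, out_baseline, factor=2.0):
    # Compare against an fp32 reference; tolerate at most `factor` times
    # the error that a trusted fp16 baseline (here, Hugging Face) incurs.
    err = (out - out_ref).abs().max().item()
    baseline_err = (out_baseline - out_ref).abs().max().item()
    assert err < factor * baseline_err, f'{err} >= {factor} * {baseline_err}'

# Hypothetical illustration: fp16 round-trips of a shared fp32 reference.
ref = torch.randn(4, 16)
assert_close_relative(ref.half().float(), ref, ref.half().float())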