flash-attention · Commit 1e712ea8

Commit 1e712ea8, authored Dec 24, 2022 by Tri Dao

Implement TensorParallel for MHA

Parent: 226a1b72

Showing 4 changed files with 199 additions and 23 deletions (+199, -23):
  csrc/rotary/rotary.cpp               +8    -5
  flash_attn/layers/rotary.py          +7    -5
  flash_attn/modules/mha.py            +75   -13
  tests/modules/test_mha_parallel.py   +109  -0
csrc/rotary/rotary.cpp

 #include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>

-#define CHECK_DEVICE(x) \
-    TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
-#define CHECK_SHAPE(x, ...) \
-    TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \
-                #x " must have shape (" #__VA_ARGS__ ")")
+#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

 void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2,
                        const torch::Tensor cos, const torch::Tensor sin,
 ...
@@ -26,6 +24,11 @@ void apply_rotary(const torch::Tensor x1, const torch::Tensor x2,
     TORCH_CHECK(x1.sizes() == x2.sizes());
     TORCH_CHECK(cos.sizes() == sin.sizes());
     TORCH_CHECK(out1.sizes() == out2.sizes());

+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)x1.get_device()};
+
     apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj);
 }
 ...
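Note on the CUDAGuard addition: without the guard, the rotary kernel would be launched from the current CUDA device (cuda:0 by default) even when the input tensors live on another GPU. A rough Python analogue of what the guard does around the launch (a sketch for illustration, not the extension's actual code):

    import torch

    def launch_on_input_device(x1, *args):
        # Analogue of at::cuda::CUDAGuard: make x1's device current for the
        # duration of the kernel launch, then restore the previous device on exit.
        with torch.cuda.device(x1.get_device()):
            pass  # the actual CUDA kernel launch happens here in the extension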
flash_attn/layers/rotary.py

...
@@ -137,17 +137,19 @@ class RotaryEmbedding(torch.nn.Module):
     """
-    def __init__(self, dim: int, base=10000, scale_base=0, *_, **__):
+    def __init__(self, dim: int, base=10000, scale_base=0, device=None):
         """
         If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
         Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
         """
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device, dtype=torch.float32) / dim))
         self.register_buffer("inv_freq", inv_freq)
         self.scale_base = scale_base
-        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) if scale_base > 0 else None
+        scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
+                 / (1.4 * dim)
+                 if scale_base > 0 else None)
         self.register_buffer("scale", scale)

         self._seq_len_cached = 0
...
@@ -168,14 +170,14 @@ class RotaryEmbedding(torch.nn.Module):
             t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
             # Don't do einsum, it converts fp32 to fp16
             # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            freqs = torch.outer(t, self.inv_freq)
+            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
             if self.scale is None:
                 self._cos_cached = torch.cos(freqs).to(x.dtype)
                 self._sin_cached = torch.sin(freqs).to(x.dtype)
             else:
                 power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                           - seqlen // 2) / self.scale_base)
-                scale = self.scale ** rearrange(power, 's -> s 1')
+                scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
                 # We want the multiplication by scale to happen in fp32
                 self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
                 self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
...
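With the new device argument, RotaryEmbedding allocates its inv_freq and scale buffers directly on the target device instead of building them on CPU and moving them later. A minimal usage sketch (sizes are illustrative):

    import torch
    from flash_attn.layers.rotary import RotaryEmbedding

    # Buffers live on cuda:0 from construction; no later .to(device) is needed.
    rotary = RotaryEmbedding(dim=64, base=10000, scale_base=0, device='cuda:0')
    print(rotary.inv_freq.device)  # cuda:0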
flash_attn/modules/mha.py

...
@@ -21,9 +21,9 @@ except ImportError:
     flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None

 try:
-    from flash_attn.ops.fused_dense import FusedDense
+    from flash_attn.ops.fused_dense import FusedDense, ColumnParallelLinear, RowParallelLinear
 except ImportError:
-    FusedDense = None
+    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

 try:
     from flash_attn.layers.rotary import RotaryEmbedding
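The new ColumnParallelLinear / RowParallelLinear imports are the Megatron-style tensor-parallel linear layers from flash_attn.ops.fused_dense: ParallelMHA (added below) uses the column-parallel layer to split Wqkv's output features across ranks and the row-parallel layer to split out_proj's input features, summing the partial products across ranks. A minimal single-process sketch of the row-parallel identity this relies on (not the fused_dense implementation):

    import torch

    torch.manual_seed(0)
    d, world_size = 8, 2
    x = torch.randn(4, d)
    W = torch.randn(d, d)  # weight of shape (out_features, in_features)

    # Each "rank" holds a slice of W's input dimension and the matching slice of x;
    # summing the partial products reproduces the full matmul (an all-reduce in the
    # distributed version).
    partials = []
    for rank in range(world_size):
        cols = slice(rank * d // world_size, (rank + 1) * d // world_size)
        partials.append(x[:, cols] @ W[:, cols].t())
    assert torch.allclose(sum(partials), x @ W.t(), atol=1e-5)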
...
@@ -42,7 +42,7 @@ class FlashSelfAttention(nn.Module):
                            (default: 0.0)
     """
     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
-                 triton=False, device=None, dtype=None):
+                 triton=False):
         super().__init__()
         if attention_dropout != 0.0 or not triton:
             assert flash_attn_unpadded_qkvpacked_func is not None, 'FlashAttention is not installed'
...
@@ -109,7 +109,7 @@ class FlashCrossAttention(nn.Module):
                            (default: 0.0)
     """
     def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
-                 triton=False, device=None, dtype=None):
+                 triton=False):
         super().__init__()
         if attention_dropout != 0.0 or not triton:
             assert flash_attn_unpadded_kvpacked_func is not None, 'FlashAttention is not installed'
...
@@ -181,8 +181,7 @@ class SelfAttention(nn.Module):
         attention_dropout: The dropout rate to apply to the attention
                            (default: 0.0)
     """
-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
-                 device=None, dtype=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
         super().__init__()
         self.causal = causal
         self.softmax_scale = softmax_scale
...
@@ -228,8 +227,7 @@ class CrossAttention(nn.Module):
         attention_dropout: The dropout rate to apply to the attention
                            (default: 0.0)
     """
-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
-                 device=None, dtype=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
         super().__init__()
         self.causal = causal
         self.softmax_scale = softmax_scale
...
@@ -309,7 +307,8 @@ class MHA(nn.Module):
         if self.rotary_emb_dim > 0:
             assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
             assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base)
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+                                              device=device)

         if fused_bias_fc and FusedDense is None:
             raise ImportError('fused_dense is not installed')
...
@@ -338,7 +337,7 @@ class MHA(nn.Module):
                                         groups=2 * embed_dim)
         inner_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
         self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
-                                         attention_dropout=dropout, **factory_kwargs)
+                                         attention_dropout=dropout)
         # output projection always have the bias (for now)
         self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
...
@@ -378,7 +377,7 @@ class MHA(nn.Module):
         if self.dwconv:
             qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                             'b d s -> b s d').contiguous()
-        qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, h=self.num_heads)
+        qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
         if self.rotary_emb_dim > 0:
             qkv = self.rotary_emb(qkv)
         if not self.checkpointing:
...
@@ -395,8 +394,8 @@ class MHA(nn.Module):
             else:
                 kv, x = self.Wkv(x)
                 q = self.Wq(x)
-            q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
-            kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, h=self.num_heads)
+            q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
+            kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=self.head_dim)
             if self.dwconv:
                 q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
                               'b d s -> b s d').contiguous()
...
@@ -408,3 +407,66 @@ class MHA(nn.Module):
             context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
         out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
         return out if not self.return_residual else (out, x)
+
+
+class ParallelMHA(nn.Module):
+    """Multi-head self-attention and cross-attention
+    """
+
+    def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
+                 softmax_scale=None, causal=False, rotary_emb_dim=0, rotary_emb_scale_base=0,
+                 use_flash_attn=False, checkpointing=False, device=None, dtype=None) -> None:
+        factory_kwargs = {'device': device, 'dtype': dtype}
+        super().__init__()
+        self.process_group = process_group
+        self.embed_dim = embed_dim
+        self.causal = causal
+        self.rotary_emb_dim = rotary_emb_dim
+        self.use_flash_attn = use_flash_attn
+        self.checkpointing = checkpointing
+
+        self.num_heads = num_heads
+        assert self.embed_dim % num_heads == 0, "self.kdim must be divisible by num_heads"
+        self.head_dim = self.embed_dim // num_heads
+
+        if self.rotary_emb_dim > 0:
+            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+                                              device=device)
+
+        if ColumnParallelLinear is None or RowParallelLinear is None:
+            raise ImportError('fused_dense is not installed')
+        self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
+                                         **factory_kwargs)
+        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
+        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
+                                         attention_dropout=dropout)
+        # output projection always have the bias (for now)
+        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, **factory_kwargs)
+
+    def forward(self, x, seqlen=None, **kwargs):
+        """
+        Arguments:
+            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
+                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
+                split x during sequence parallel, we split the batch * seqlen dimension
+                (in case batch is small).
+        """
+        qkv = self.Wqkv(x)
+        if seqlen is None:
+            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, d=self.head_dim)
+        else:
+            qkv = rearrange(qkv, '(b s) (three h d) -> b s three h d', s=seqlen, three=3,
+                            d=self.head_dim)
+        if self.rotary_emb_dim > 0:
+            qkv = self.rotary_emb(qkv)
+        if not self.checkpointing:
+            context = self.inner_attn(qkv, **kwargs)
+        else:
+            context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
+        if seqlen is None:
+            context = rearrange(context, 'b s h d -> b s (h d)')
+        else:
+            context = rearrange(context, 'b s h d -> (b s) (h d)')
+        out = self.out_proj(context)
+        return out
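A minimal usage sketch for the new ParallelMHA, following the pattern of tests/modules/test_mha_parallel.py below; the launch command, sizes and dtype are illustrative:

    # Launch with e.g.: torchrun --nproc_per_node=2 parallel_mha_example.py
    import torch
    from apex.transformer import parallel_state, tensor_parallel
    from flash_attn.modules.mha import ParallelMHA

    torch.distributed.init_process_group(backend='nccl', init_method='env://')
    world_size = torch.distributed.get_world_size()
    device = f'cuda:{torch.distributed.get_rank()}'
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)

    batch_size, seqlen, embed_dim = 8, 512, 1024
    mha = ParallelMHA(embed_dim, 16, parallel_state.get_tensor_model_parallel_group(),
                      rotary_emb_dim=32, use_flash_attn=True,
                      device=device, dtype=torch.float16)

    # Flatten batch and sequence, then give each rank its (batch * seqlen / world_size)
    # slice, mirroring the sequence-parallel setup used in the test.
    x = torch.randn(batch_size * seqlen, embed_dim, device=device, dtype=torch.float16)
    x_local = tensor_parallel.scatter_to_sequence_parallel_region(x)
    out_local = mha(x_local, seqlen=seqlen)  # each rank holds its slice of the output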
tests/modules/test_mha_parallel.py (new file, +109)

# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mha_parallel.py

import math

import torch
import torch.nn.functional as F
import pytest

from einops import rearrange

from apex.transformer import parallel_state
from apex.transformer import tensor_parallel

from flash_attn.modules.mha import MHA, ParallelMHA

is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8


@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('head_dim', [64, 128])
# @pytest.mark.parametrize('head_dim', [64])
@pytest.mark.parametrize('embed_dim', [1024, 4096])
# @pytest.mark.parametrize('embed_dim', [1024])
def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
    assert embed_dim % head_dim == 0
    num_heads = embed_dim // head_dim
    assert num_heads % world_size == 0
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    x_pt = torch.randn(batch_size * seqlen, embed_dim, device=device, dtype=dtype,
                       requires_grad=True)
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()

    model_pt = MHA(embed_dim, num_heads, rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
                   device=device, dtype=dtype)
    partition_dim = embed_dim // world_size
    model = ParallelMHA(embed_dim, num_heads, parallel_state.get_tensor_model_parallel_group(),
                        rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
                        device=device, dtype=dtype)

    with torch.no_grad():
        model.Wqkv.weight.copy_(
            rearrange(rearrange(model_pt.Wqkv.weight, '(three o) i -> three o i', three=3)
                      [:, rank * partition_dim:(rank + 1) * partition_dim],
                      'three o i -> (three o) i')
        )
        model.Wqkv.bias.copy_(
            rearrange(rearrange(model_pt.Wqkv.bias, '(three o) -> three o', three=3)
                      [:, rank * partition_dim:(rank + 1) * partition_dim],
                      'three o -> (three o)')
        )
        model.out_proj.weight.copy_(
            model_pt.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
        )
        if rank == 0:
            model.out_proj.bias.copy_(model_pt.out_proj.bias)

    out = model(x, seqlen=seqlen)
    out_pt = rearrange(model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen)), 'b s d -> (b s) d')
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
                          rtol=rtol, atol=atol)

    out_pt.backward(g)
    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad,
        x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
        rtol=rtol, atol=atol
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.Wqkv.weight.grad,
        rearrange(rearrange(model_pt.Wqkv.weight.grad, '(three o) i -> three o i', three=3)
                  [:, rank * partition_dim:(rank + 1) * partition_dim],
                  'three o i -> (three o) i'),
        rtol=rtol, atol=atol * 10
    )
    assert torch.allclose(
        model.Wqkv.bias.grad,
        rearrange(rearrange(model_pt.Wqkv.bias.grad, '(three o) -> three o', three=3)
                  [:, rank * partition_dim:(rank + 1) * partition_dim],
                  'three o -> (three o)'),
        rtol=rtol, atol=atol * 5
    )
    assert torch.allclose(
        model.out_proj.weight.grad,
        model_pt.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
        rtol=rtol, atol=atol * 10
    )
    if rank == 0:
        assert torch.allclose(model.out_proj.bias.grad, model_pt.out_proj.bias.grad,
                              rtol=rtol, atol=atol * 5)
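The per-rank weight slicing in the test is the column-parallel split in reverse: the full (3 * embed_dim, embed_dim) Wqkv weight is viewed as (three, embed_dim, embed_dim) so that each rank receives its contiguous block of output features for q, k and v. A small standalone sketch of that indexing (illustrative sizes):

    import torch
    from einops import rearrange

    embed_dim, world_size, rank = 8, 2, 0
    partition_dim = embed_dim // world_size
    w_full = torch.randn(3 * embed_dim, embed_dim)
    # Same rearrange-slice-rearrange pattern as the copy_ calls in the test above.
    w_rank = rearrange(rearrange(w_full, '(three o) i -> three o i', three=3)
                       [:, rank * partition_dim:(rank + 1) * partition_dim],
                       'three o i -> (three o) i')
    print(w_rank.shape)  # torch.Size([12, 8]) == (3 * embed_dim // world_size, embed_dim)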