ModelZoo / Yuan2.0-M32_pytorch, commit d3dd8642, authored Jun 26, 2024 by Rayyyyy
First add
Pipeline #1259 failed in 0 seconds.
Changes: 315. Pipelines: 1.
Showing 15 changed files with 3970 additions and 0 deletions (+3970, -0).
megatron/model/vision/mit_backbone.py      +420   -0
megatron/model/vision/swin_backbone.py     +625   -0
megatron/model/vision/utils.py              +27   -0
megatron/model/vision/vit_backbone.py      +245   -0
megatron/model/yuan_hf_model.py           +1222   -0
megatron/model/yuan_model.py               +122   -0
megatron/mpu/tests/__init__.py               +0   -0
megatron/mpu/tests/commons.py               +70   -0
megatron/mpu/tests/test_cross_entropy.py    +95   -0
megatron/mpu/tests/test_data.py             +75   -0
megatron/mpu/tests/test_initialize.py       +82   -0
megatron/mpu/tests/test_layers.py          +517   -0
megatron/mpu/tests/test_random.py          +191   -0
megatron/optimizer/__init__.py             +144   -0
megatron/optimizer/clip_grads.py           +135   -0
megatron/model/vision/mit_backbone.py (new file, mode 100644)
# ---------------------------------------------------------------
# Copyright (c) 2021, NVIDIA Corporation. All rights reserved.
#
# This work is licensed under the NVIDIA Source Code License
# found in the LICENSE file in the root directory of this
# source tree.
# ---------------------------------------------------------------
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from functools import partial

from torch.nn.init import trunc_normal_
from megatron.model.transformer import DropPath
from megatron.model import LayerNorm


class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.dwconv = DWConv(hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = self.fc1(x)
        x = self.dwconv(x, H, W)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


class Attention(nn.Module):
    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1):
        super().__init__()
        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."

        self.dim = dim
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.q = nn.Linear(dim, dim, bias=qkv_bias)
        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.sr_ratio = sr_ratio
        if sr_ratio > 1:
            self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
            self.norm = LayerNorm(dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        B, N, C = x.shape
        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)

        if self.sr_ratio > 1:
            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
            x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
            x_ = self.norm(x_)
            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        else:
            kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        k, v = kv[0], kv[1]

        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = attn.softmax(dim=-1)
        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)

        return x


class Block(nn.Module):

    def __init__(self, dim, num_heads, mlp_ratio=4., qkv_bias=False, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., act_layer=nn.GELU, norm_layer=LayerNorm, sr_ratio=1):
        super().__init__()
        self.norm1 = norm_layer(dim)
        self.attn = Attention(
            dim,
            num_heads=num_heads, qkv_bias=qkv_bias, qk_scale=qk_scale,
            attn_drop=attn_drop, proj_drop=drop, sr_ratio=sr_ratio)
        # NOTE: drop path for stochastic depth, we shall see if this is better than dropout here
        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x, H, W):
        x = x + self.drop_path(self.attn(self.norm1(x), H, W))
        x = x + self.drop_path(self.mlp(self.norm2(x), H, W))

        return x


class OverlapPatchEmbed(nn.Module):
    """ Image to Patch Embedding
    """

    def __init__(self, img_size=224, patch_size=7, stride=4, in_chans=3, embed_dim=768):
        super().__init__()
        img_size = (img_size, img_size)
        patch_size = (patch_size, patch_size)

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride,
                              padding=(patch_size[0] // 2, patch_size[1] // 2))
        self.norm = LayerNorm(embed_dim)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def forward(self, x):
        x = self.proj(x)
        _, _, H, W = x.shape
        x = x.flatten(2).transpose(1, 2)
        x = self.norm(x)

        return x, H, W


class MixVisionTransformer(nn.Module):
    def __init__(self, img_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, qk_scale=None, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], output_avg=False):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.output_avg = output_avg

        # patch_embed
        self.patch_embed1 = OverlapPatchEmbed(img_size=img_size, patch_size=7, stride=4, in_chans=in_chans,
                                              embed_dim=embed_dims[0])
        self.patch_embed2 = OverlapPatchEmbed(img_size=img_size // 4, patch_size=3, stride=2, in_chans=embed_dims[0],
                                              embed_dim=embed_dims[1])
        self.patch_embed3 = OverlapPatchEmbed(img_size=img_size // 8, patch_size=3, stride=2, in_chans=embed_dims[1],
                                              embed_dim=embed_dims[2])
        self.patch_embed4 = OverlapPatchEmbed(img_size=img_size // 16, patch_size=3, stride=2, in_chans=embed_dims[2],
                                              embed_dim=embed_dims[3])

        # transformer encoder
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0
        self.block1 = nn.ModuleList([Block(
            dim=embed_dims[0], num_heads=num_heads[0], mlp_ratio=mlp_ratios[0], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[0])
            for i in range(depths[0])])
        self.norm1 = norm_layer(embed_dims[0])

        cur += depths[0]
        self.block2 = nn.ModuleList([Block(
            dim=embed_dims[1], num_heads=num_heads[1], mlp_ratio=mlp_ratios[1], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[1])
            for i in range(depths[1])])
        self.norm2 = norm_layer(embed_dims[1])

        cur += depths[1]
        self.block3 = nn.ModuleList([Block(
            dim=embed_dims[2], num_heads=num_heads[2], mlp_ratio=mlp_ratios[2], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[2])
            for i in range(depths[2])])
        self.norm3 = norm_layer(embed_dims[2])

        cur += depths[2]
        self.block4 = nn.ModuleList([Block(
            dim=embed_dims[3], num_heads=num_heads[3], mlp_ratio=mlp_ratios[3], qkv_bias=qkv_bias, qk_scale=qk_scale,
            drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + i], norm_layer=norm_layer,
            sr_ratio=sr_ratios[3])
            for i in range(depths[3])])
        self.norm4 = norm_layer(embed_dims[3])

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)
        elif isinstance(m, nn.Conv2d):
            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
            fan_out //= m.groups
            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
            if m.bias is not None:
                m.bias.data.zero_()

    def reset_drop_path(self, drop_path_rate):
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(self.depths))]
        cur = 0
        for i in range(self.depths[0]):
            self.block1[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[0]
        for i in range(self.depths[1]):
            self.block2[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[1]
        for i in range(self.depths[2]):
            self.block3[i].drop_path.drop_prob = dpr[cur + i]

        cur += self.depths[2]
        for i in range(self.depths[3]):
            self.block4[i].drop_path.drop_prob = dpr[cur + i]

    def freeze_patch_emb(self):
        self.patch_embed1.requires_grad = False

    def forward_features(self, x):
        B = x.shape[0]
        outs = []

        # stage 1
        x, H, W = self.patch_embed1(x)
        for i, blk in enumerate(self.block1):
            x = blk(x, H, W)
        x = self.norm1(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 2
        x, H, W = self.patch_embed2(x)
        for i, blk in enumerate(self.block2):
            x = blk(x, H, W)
        x = self.norm2(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 3
        x, H, W = self.patch_embed3(x)
        for i, blk in enumerate(self.block3):
            x = blk(x, H, W)
        x = self.norm3(x)
        x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        # stage 4
        x, H, W = self.patch_embed4(x)
        for i, blk in enumerate(self.block4):
            x = blk(x, H, W)
        x = self.norm4(x)
        if not self.output_avg:
            x = x.reshape(B, H, W, -1).permute(0, 3, 1, 2).contiguous()
        outs.append(x)

        return outs

    def forward(self, x):
        x = self.forward_features(x)

        if self.output_avg:
            x = x[3].mean(dim=1)

        return x


class DWConv(nn.Module):
    def __init__(self, dim=768):
        super(DWConv, self).__init__()
        self.dwconv = nn.Conv2d(dim, dim, 3, 1, 1, bias=True, groups=dim)

    def forward(self, x, H, W):
        B, N, C = x.shape
        x = x.transpose(1, 2).view(B, C, H, W)
        x = self.dwconv(x)
        x = x.flatten(2).transpose(1, 2)

        return x


class mit_b0(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b0, self).__init__(
            patch_size=4, embed_dims=[32, 64, 160, 256], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b1(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b1, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[2, 2, 2, 2], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b2(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b2, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b3(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b3, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b3_avg(MixVisionTransformer):
    def __init__(self, drop_path_rate=0.1, **kwargs):
        super(mit_b3_avg, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 4, 18, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True)


class mit_b4(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b4, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 8, 27, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b5(MixVisionTransformer):
    def __init__(self, **kwargs):
        super(mit_b5, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=0.1)


class mit_b5_avg(MixVisionTransformer):
    def __init__(self, drop_path_rate=0.1, **kwargs):
        super(mit_b5_avg, self).__init__(
            patch_size=4, embed_dims=[64, 128, 320, 512], num_heads=[1, 2, 5, 8], mlp_ratios=[4, 4, 4, 4],
            qkv_bias=True, norm_layer=partial(LayerNorm, eps=1e-6), depths=[3, 6, 40, 3], sr_ratios=[8, 4, 2, 1],
            drop_rate=0.0, drop_path_rate=drop_path_rate, output_avg=True)
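A minimal usage sketch for the backbones above (an editorial illustration, not part of the commit). It assumes a working Megatron-LM install so that the megatron.model imports at the top of the file resolve (the fused LayerNorm typically needs a GPU environment). With the default output_avg=False, forward() returns the four stage feature maps; the _avg variants instead return the last stage averaged over tokens.

    import torch
    from megatron.model.vision.mit_backbone import mit_b5, mit_b3_avg

    backbone = mit_b5()                                  # multi-scale features for dense prediction
    feats = backbone(torch.randn(2, 3, 224, 224))
    print([f.shape for f in feats])                      # [2, 64, 56, 56] ... [2, 512, 7, 7]

    pooled = mit_b3_avg()(torch.randn(2, 3, 224, 224))   # last stage averaged over tokens
    print(pooled.shape)                                  # torch.Size([2, 512])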
megatron/model/vision/swin_backbone.py (new file, mode 100644)
# Copyright (c) 2021 Microsoft
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# Swin Transformer
# --------------------------------------------------------
import numpy as np  # used by SwinTransformerBlock.create_attn_mask (np.ceil)
import torch
import torch.nn as nn
import torch.utils.checkpoint as checkpoint
from timm.models.layers import DropPath, to_2tuple, trunc_normal_
from math import sqrt

from megatron import get_args
from functools import partial
class Mlp(nn.Module):
    def __init__(self, in_features, hidden_features=None, out_features=None, act_layer=nn.GELU, drop=0.):
        super().__init__()
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.act = act_layer()
        self.fc2 = nn.Linear(hidden_features, out_features)
        self.drop = nn.Dropout(drop)

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.drop(x)
        x = self.fc2(x)
        x = self.drop(x)
        return x


def window_partition(x, window_size):
    """
    Args:
        x: (B, H, W, C)
        window_size (int): window size

    Returns:
        windows: (num_windows*B, window_size, window_size, C)
    """
    B, H, W, C = x.shape
    x = x.view(B, H // window_size, window_size, W // window_size, window_size, C)
    windows = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(-1, window_size, window_size, C)
    return windows


def window_reverse(windows, window_size, H, W):
    """
    Args:
        windows: (num_windows*B, window_size, window_size, C)
        window_size (int): Window size
        H (int): Height of image
        W (int): Width of image

    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.view(B, H // window_size, W // window_size, window_size, window_size, -1)
    x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
    return x
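# Editorial illustration, not part of the committed file: a quick sanity sketch of the two
# helpers above. Partitioning into non-overlapping windows and then reversing the partition
# with the same window size reproduces the input exactly.
def _window_roundtrip_demo():
    x = torch.randn(2, 56, 56, 96)                   # (B, H, W, C)
    windows = window_partition(x, window_size=7)     # (2 * 8 * 8, 7, 7, 96)
    x_back = window_reverse(windows, 7, H=56, W=56)  # back to (B, H, W, C)
    assert torch.equal(x, x_back)
    return x_back.shape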
class WindowAttention(nn.Module):
    r""" Window based multi-head self attention (W-MSA) module with relative position bias.
    It supports both of shifted and non-shifted window.

    Args:
        dim (int): Number of input channels.
        window_size (tuple[int]): The height and width of the window.
        num_heads (int): Number of attention heads.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set
        attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
        proj_drop (float, optional): Dropout ratio of output. Default: 0.0
    """

    def __init__(self, dim, window_size, num_heads, qkv_bias=True, qk_scale=None, attn_drop=0., proj_drop=0.):
        super().__init__()
        self.dim = dim
        self.window_size = window_size  # Wh, Ww
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        # define a parameter table of relative position bias
        self.relative_position_bias_table = nn.Parameter(
            torch.zeros((2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads))  # 2*Wh-1 * 2*Ww-1, nH

        # get pair-wise relative position index for each token inside the window
        coords_h = torch.arange(self.window_size[0])
        coords_w = torch.arange(self.window_size[1])
        coords = torch.stack(torch.meshgrid([coords_h, coords_w]))  # 2, Wh, Ww
        coords_flatten = torch.flatten(coords, 1)  # 2, Wh*Ww
        relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]  # 2, Wh*Ww, Wh*Ww
        relative_coords = relative_coords.permute(1, 2, 0).contiguous()  # Wh*Ww, Wh*Ww, 2
        relative_coords[:, :, 0] += self.window_size[0] - 1  # shift to start from 0
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        relative_position_index = relative_coords.sum(-1)  # Wh*Ww, Wh*Ww
        self.register_buffer("relative_position_index", relative_position_index)

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        trunc_normal_(self.relative_position_bias_table, std=.02)
        self.softmax = nn.Softmax(dim=-1)

    def forward(self, x, mask=None):
        """
        Args:
            x: input features with shape of (num_windows*B, N, C)
            mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
        """
        B_, N, C = x.shape
        qkv = self.qkv(x).reshape(B_, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        relative_position_bias = self.relative_position_bias_table[self.relative_position_index.view(-1)].view(
            self.window_size[0] * self.window_size[1], self.window_size[0] * self.window_size[1], -1)  # Wh*Ww,Wh*Ww,nH
        relative_position_bias = relative_position_bias.permute(2, 0, 1).contiguous()  # nH, Wh*Ww, Wh*Ww
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
            attn = self.softmax(attn)
        else:
            attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    def extra_repr(self) -> str:
        return f'dim={self.dim}, window_size={self.window_size}, num_heads={self.num_heads}'

    def flops(self, N):
        # calculate flops for 1 window with token length of N
        flops = 0
        # qkv = self.qkv(x)
        flops += N * self.dim * 3 * self.dim
        # attn = (q @ k.transpose(-2, -1))
        flops += self.num_heads * N * (self.dim // self.num_heads) * N
        # x = (attn @ v)
        flops += self.num_heads * N * N * (self.dim // self.num_heads)
        # x = self.proj(x)
        flops += N * self.dim * self.dim
        return flops
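# Editorial illustration, not part of the committed file: how the relative-position-index
# buffer above is laid out, worked through for a tiny 2x2 window. Every ordered pair of the
# 4 window positions maps to one of (2*2 - 1) * (2*2 - 1) = 9 rows of the bias table.
def _relative_position_index_demo():
    Wh, Ww = 2, 2
    coords = torch.stack(torch.meshgrid([torch.arange(Wh), torch.arange(Ww)]))  # 2, Wh, Ww
    coords_flatten = torch.flatten(coords, 1)                                   # 2, Wh*Ww
    relative_coords = coords_flatten[:, :, None] - coords_flatten[:, None, :]   # 2, Wh*Ww, Wh*Ww
    relative_coords = relative_coords.permute(1, 2, 0).contiguous()
    relative_coords[:, :, 0] += Wh - 1      # shift row offsets to start from 0
    relative_coords[:, :, 1] += Ww - 1      # shift column offsets to start from 0
    relative_coords[:, :, 0] *= 2 * Ww - 1  # row-major flattening of the (dy, dx) pair
    return relative_coords.sum(-1)
    # tensor([[4, 3, 1, 0],
    #         [5, 4, 2, 1],
    #         [7, 6, 4, 3],
    #         [8, 7, 5, 4]])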
class SwinTransformerBlock(nn.Module):
    r""" Swin Transformer Block.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        num_heads (int): Number of attention heads.
        window_size (int): Window size.
        shift_size (int): Shift size for SW-MSA.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float, optional): Stochastic depth rate. Default: 0.0
        act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0., drop_path=0.,
                 act_layer=nn.GELU, norm_layer=nn.LayerNorm):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            # if window size is larger than input resolution, we don't partition windows
            self.shift_size = 0
            self.window_size = min(self.input_resolution)
        assert 0 <= self.shift_size < self.window_size, "shift_size must in 0-window_size"

        self.norm1 = norm_layer(dim)
        self.attn = WindowAttention(
            dim, window_size=to_2tuple(self.window_size), num_heads=num_heads,
            qkv_bias=qkv_bias, qk_scale=qk_scale, attn_drop=attn_drop, proj_drop=drop)

        self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(in_features=dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)

        self.H = input_resolution[0]
        self.W = input_resolution[1]

        self.attn_mask_dict = {}

    def create_attn_mask(self, H, W):
        # calculate attention mask for SW-MSA

        Hp = int(np.ceil(H / self.window_size)) * self.window_size
        Wp = int(np.ceil(W / self.window_size)) * self.window_size
        img_mask = torch.zeros((1, Hp, Wp, 1))  # 1 Hp Wp 1
        h_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        w_slices = (slice(0, -self.window_size),
                    slice(-self.window_size, -self.shift_size),
                    slice(-self.shift_size, None))
        cnt = 0
        for h in h_slices:
            for w in w_slices:
                img_mask[:, h, w, :] = cnt
                cnt += 1

        mask_windows = window_partition(img_mask, self.window_size)  # nW, window_size, window_size, 1
        mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
        attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
        attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))

        return attn_mask

    def forward(self, x):
        B, L, C = x.shape
        H = int(sqrt(L))
        W = H

        shortcut = x
        x = self.norm1(x)
        x = x.view(B, H, W, C)

        # cyclic shift
        if self.shift_size > 0:
            shifted_x = torch.roll(x, shifts=(-self.shift_size, -self.shift_size), dims=(1, 2))
        else:
            shifted_x = x

        # partition windows
        x_windows = window_partition(shifted_x, self.window_size)  # nW*B, window_size, window_size, C
        x_windows = x_windows.view(-1, self.window_size * self.window_size, C)  # nW*B, window_size*window_size, C

        # W-MSA/SW-MSA
        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # nW*B, window_size*window_size, C

        # merge windows
        attn_windows = attn_windows.view(-1, self.window_size, self.window_size, C)
        shifted_x = window_reverse(attn_windows, self.window_size, H, W)  # B H' W' C

        # reverse cyclic shift
        if self.shift_size > 0:
            x = torch.roll(shifted_x, shifts=(self.shift_size, self.shift_size), dims=(1, 2))
        else:
            x = shifted_x
        x = x.view(B, H * W, C)

        # FFN
        x = shortcut + self.drop_path(x)
        x = x + self.drop_path(self.mlp(self.norm2(x)))

        return x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " \
               f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"

    def flops(self):
        flops = 0
        H, W = self.input_resolution
        # norm1
        flops += self.dim * H * W
        # W-MSA/SW-MSA
        nW = H * W / self.window_size / self.window_size
        flops += nW * self.attn.flops(self.window_size * self.window_size)
        # mlp
        flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
        # norm2
        flops += self.dim * H * W
        return flops
class PatchMerging(nn.Module):
    r""" Patch Merging Layer.

    Args:
        input_resolution (tuple[int]): Resolution of input feature.
        dim (int): Number of input channels.
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
    """

    def __init__(self, input_resolution, dim, norm_layer=nn.LayerNorm):
        super().__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
        self.norm = norm_layer(4 * dim)

    def forward(self, x):
        """
        x: B, H*W, C
        """
        H, W = self.input_resolution
        B, L, C = x.shape
        assert L == H * W, "input feature has wrong size"
        assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."

        x = x.view(B, H, W, C)

        x0 = x[:, 0::2, 0::2, :]  # B H/2 W/2 C
        x1 = x[:, 1::2, 0::2, :]  # B H/2 W/2 C
        x2 = x[:, 0::2, 1::2, :]  # B H/2 W/2 C
        x3 = x[:, 1::2, 1::2, :]  # B H/2 W/2 C
        x = torch.cat([x0, x1, x2, x3], -1)  # B H/2 W/2 4*C
        x = x.view(B, -1, 4 * C)  # B H/2*W/2 4*C

        x = self.norm(x)
        x = self.reduction(x)

        return x

    def extra_repr(self) -> str:
        return f"input_resolution={self.input_resolution}, dim={self.dim}"

    def flops(self):
        H, W = self.input_resolution
        flops = H * W * self.dim
        flops += (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
        return flops


class BasicLayer(nn.Module):
    """ A basic Swin Transformer layer for one stage.

    Args:
        dim (int): Number of input channels.
        input_resolution (tuple[int]): Input resolution.
        depth (int): Number of blocks.
        num_heads (int): Number of attention heads.
        window_size (int): Local window size.
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
        qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float | None, optional): Override default qk scale of head_dim ** -0.5 if set.
        drop (float, optional): Dropout rate. Default: 0.0
        attn_drop (float, optional): Attention dropout rate. Default: 0.0
        drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
        norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
        downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
    """

    def __init__(self, dim, input_resolution, depth, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, drop=0., attn_drop=0.,
                 drop_path=0., norm_layer=nn.LayerNorm, downsample=None, use_checkpoint=False):
        super().__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.depth = depth
        self.use_checkpoint = use_checkpoint

        # build blocks
        self.blocks = nn.ModuleList([
            SwinTransformerBlock(dim=dim, input_resolution=input_resolution,
                                 num_heads=num_heads, window_size=window_size,
                                 shift_size=0 if (i % 2 == 0) else window_size // 2,
                                 mlp_ratio=mlp_ratio,
                                 qkv_bias=qkv_bias, qk_scale=qk_scale,
                                 drop=drop, attn_drop=attn_drop,
                                 drop_path=drop_path[i] if isinstance(drop_path, list) else drop_path,
                                 norm_layer=norm_layer)
            for i in range(depth)])

        # patch merging layer
        if downsample is not None:
            self.downsample = downsample(input_resolution, dim=dim, norm_layer=norm_layer)
        else:
            self.downsample = None

    def forward(self, x):
        for blk in self.blocks:
            if self.use_checkpoint:
                x = checkpoint.checkpoint(blk, x)
            else:
                x = blk(x)

        x_b4_ds = x
        if self.downsample is not None:
            x = self.downsample(x)
        return x_b4_ds, x

    def extra_repr(self) -> str:
        return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"

    def flops(self):
        flops = 0
        for blk in self.blocks:
            flops += blk.flops()
        if self.downsample is not None:
            flops += self.downsample.flops()
        return flops
class PatchEmbed(nn.Module):
    r""" Image to Patch Embedding

    Args:
        img_size (int): Image size. Default: 224.
        patch_size (int): Patch token size. Default: 4.
        in_chans (int): Number of input image channels. Default: 3.
        embed_dim (int): Number of linear projection output channels. Default: 96.
        norm_layer (nn.Module, optional): Normalization layer. Default: None
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        patches_resolution = [img_size[0] // patch_size[0], img_size[1] // patch_size[1]]
        self.img_size = img_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]

        self.in_chans = in_chans
        self.embed_dim = embed_dim

        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size)
        if norm_layer is not None:
            self.norm = norm_layer(embed_dim)
        else:
            self.norm = None

    def forward(self, x):
        B, C, H, W = x.shape
        # FIXME look at relaxing size constraints
        assert H == self.img_size[0] and W == self.img_size[1], \
            f"Input image size ({H}*{W}) doesn't match model ({self.img_size[0]}*{self.img_size[1]})."
        x = self.proj(x).flatten(2).transpose(1, 2)  # B Ph*Pw C
        if self.norm is not None:
            x = self.norm(x)
        return x

    def flops(self):
        Ho, Wo = self.patches_resolution
        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
        if self.norm is not None:
            flops += Ho * Wo * self.embed_dim
        return flops
class SwinTransformer(nn.Module):
    r""" Swin Transformer
        A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
        https://arxiv.org/pdf/2103.14030

    Args:
        img_size (int | tuple(int)): Input image size. Default 224
        patch_size (int | tuple(int)): Patch size. Default: 4
        in_chans (int): Number of input image channels. Default: 3
        embed_dim (int): Patch embedding dimension. Default: 96
        depths (tuple(int)): Depth of each Swin Transformer layer.
        num_heads (tuple(int)): Number of attention heads in different layers.
        window_size (int): Window size. Default: 7
        mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
        qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
        qk_scale (float): Override default qk scale of head_dim ** -0.5 if set. Default: None
        drop_rate (float): Dropout rate. Default: 0
        attn_drop_rate (float): Attention dropout rate. Default: 0
        drop_path_rate (float): Stochastic depth rate. Default: 0.1
        norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
        ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
        patch_norm (bool): If True, add normalization after patch embedding. Default: True
        use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
    """

    def __init__(self, img_size=224, patch_size=4, in_chans=3,
                 embed_dim=96, depths=[2, 2, 6, 2], num_heads=[3, 6, 12, 24],
                 window_size=7, mlp_ratio=4., qkv_bias=True, qk_scale=None,
                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0.3,
                 norm_layer=partial(nn.LayerNorm, eps=1e-6), ape=False, patch_norm=True,
                 use_checkpoint=False, output_avg=False, **kwargs):
        super().__init__()

        self.num_layers = len(depths)
        self.embed_dim = embed_dim
        self.ape = ape
        self.patch_norm = patch_norm
        self.num_features = int(embed_dim * 2 ** (self.num_layers - 1))
        self.mlp_ratio = mlp_ratio
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)
        self.output_avg = output_avg

        # split image into non-overlapping patches
        self.patch_embed = PatchEmbed(
            img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim,
            norm_layer=norm_layer if self.patch_norm else None)
        num_patches = self.patch_embed.num_patches
        patches_resolution = self.patch_embed.patches_resolution
        self.patches_resolution = patches_resolution

        # absolute position embedding
        if self.ape:
            self.absolute_pos_embed = nn.Parameter(torch.zeros(1, num_patches, embed_dim))
            trunc_normal_(self.absolute_pos_embed, std=.02)

        self.pos_drop = nn.Dropout(p=drop_rate)

        # stochastic depth
        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule

        # build layers
        self.layers = nn.ModuleList()
        for i_layer in range(self.num_layers):
            layer = BasicLayer(dim=int(embed_dim * 2 ** i_layer),
                               input_resolution=(patches_resolution[0] // (2 ** i_layer),
                                                 patches_resolution[1] // (2 ** i_layer)),
                               depth=depths[i_layer],
                               num_heads=num_heads[i_layer],
                               window_size=window_size,
                               mlp_ratio=self.mlp_ratio,
                               qkv_bias=qkv_bias, qk_scale=qk_scale,
                               drop=drop_rate, attn_drop=attn_drop_rate,
                               drop_path=dpr[sum(depths[:i_layer]):sum(depths[:i_layer + 1])],
                               norm_layer=norm_layer,
                               downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
                               use_checkpoint=use_checkpoint)
            self.layers.append(layer)

        self.apply(self._init_weights)

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    @torch.jit.ignore
    def no_weight_decay(self):
        return {'absolute_pos_embed'}

    @torch.jit.ignore
    def no_weight_decay_keywords(self):
        return {'relative_position_bias_table'}

    def forward(self, x):
        x = self.patch_embed(x)
        if self.ape:
            x = x + self.absolute_pos_embed
        x = self.pos_drop(x)

        h = self.img_size[0] // self.patch_size[0]
        w = self.img_size[1] // self.patch_size[1]

        outs = []
        for i, layer in enumerate(self.layers):
            px, x = layer(x)
            b, n, c = px.shape

            if i != len(self.layers) - 1 or not self.output_avg:
                px = px.permute(0, 2, 1).contiguous()
                px = px.reshape(b, c, h, w)

            # is this a fair assumption ?? i think it's baked into the architecture
            h, w = h // 2, w // 2
            outs.append(px)

        if self.output_avg:
            return outs[-1].mean(dim=1)

        return outs

    def flops(self):
        flops = 0
        flops += self.patch_embed.flops()
        for i, layer in enumerate(self.layers):
            flops += layer.flops()
        flops += self.num_features * self.patches_resolution[0] * self.patches_resolution[1] // (2 ** self.num_layers)
        flops += self.num_features * self.num_classes
        return flops
def get_swin(drop_path_rate=0.3, output_avg=False):
    args = get_args()

    window_size = 7
    embed_dim = 128
    depths = [2, 2, 18, 2]
    num_heads = [4, 8, 16, 32]
    swin = SwinTransformer(
        img_size=(args.img_h, args.img_w,),
        in_chans=3,
        patch_size=args.patch_dim,
        embed_dim=embed_dim,
        depths=depths,
        num_heads=num_heads,
        window_size=window_size,
        drop_path_rate=drop_path_rate,
        output_avg=output_avg,
    )

    return swin
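For reference (an editorial sketch, not part of the commit): get_swin() wires up a Swin-B style backbone, so the per-stage channel counts follow the usual doubling rule applied by BasicLayer and PatchMerging. The arithmetic below only restates the configuration hard-coded above and does not require the Megatron runtime.

    embed_dim, depths, num_heads = 128, [2, 2, 18, 2], [4, 8, 16, 32]
    stage_dims = [embed_dim * 2 ** i for i in range(len(depths))]
    print(stage_dims)  # [128, 256, 512, 1024]; spatial resolution shrinks from img/4 down to img/32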
megatron/model/vision/utils.py (new file, mode 100644)
import warnings
import torch
import torch.nn.functional as F


def resize(input,
           size=None,
           scale_factor=None,
           mode='nearest',
           align_corners=None,
           warning=True):
    if warning:
        if size is not None and align_corners:
            input_h, input_w = tuple(int(x) for x in input.shape[2:])
            output_h, output_w = tuple(int(x) for x in size)
            if output_h > input_h or output_w > input_w:
                if ((output_h > 1 and output_w > 1 and input_h > 1
                     and input_w > 1) and (output_h - 1) % (input_h - 1)
                        and (output_w - 1) % (input_w - 1)):
                    warnings.warn(
                        f'When align_corners={align_corners}, '
                        'the output would be more aligned if '
                        f'input size {(input_h, input_w)} is `x+1` and '
                        f'out size {(output_h, output_w)} is `nx+1`')
    if isinstance(size, torch.Size):
        size = tuple(int(x) for x in size)
    return F.interpolate(input, size, scale_factor, mode, align_corners)
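Example usage of resize() (an editorial sketch, not part of the commit), assuming the module is importable as megatron.model.vision.utils: it is a thin wrapper around F.interpolate that additionally warns about awkward align_corners sizes and accepts a torch.Size as the target size.

    import torch
    from megatron.model.vision.utils import resize

    feat = torch.randn(1, 256, 32, 32)
    out = resize(feat, size=(128, 128), mode='bilinear', align_corners=False)
    print(out.shape)   # torch.Size([1, 256, 128, 128])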
megatron/model/vision/vit_backbone.py (new file, mode 100644)
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
"""Vision Transformer(VIT) model."""
import math
import einops
import torch
import apex
import torch.nn.functional as F
from megatron import get_args
from megatron.model.transformer import ParallelTransformer
from megatron.model.utils import (
    get_linear_layer,
    init_method_normal,
    scaled_init_method_normal,
)
from megatron.model.module import MegatronModule

CLASS_TOKEN_LENGTH = 8
class VitMlpHead(MegatronModule):
    """Pooler layer.

    Pool hidden states of a specific token (for example start of the
    sequence) and add a linear transformation followed by a tanh.

    Arguments:
        hidden_size: hidden size
        init_method: weight initialization method for the linear layer.
            bias is set to zero.
    """

    def __init__(self, hidden_size, num_classes):
        super(VitMlpHead, self).__init__()
        self.dense_in = torch.nn.Linear(hidden_size, hidden_size)
        self.relu = torch.nn.ReLU()
        self.dense_out = torch.nn.Linear(hidden_size, num_classes)
        torch.nn.init.constant_(self.dense_out.bias, -10)

    def forward(self, hidden_states):
        # hidden_states: [b, 1, h]
        # sequence_index: index of the token to pool.
        dense_in_result = self.dense_in(hidden_states)
        tanh_result = torch.tanh(dense_in_result)
        dense_out_result = self.dense_out(tanh_result)
        return dense_out_result


def isPerfectSquare(x):
    if (x >= 0):
        sr = math.sqrt(x)
        return (int(sr) * int(sr) == x)
    return False
def twod_interpolate_position_embeddings_hook(
    state_dict,
    prefix,
    local_metadata,
    strict,
    missing_keys,
    unexpected_keys,
    error_msgs,
):
    args = get_args()
    num_patches_per_dim_h = args.img_h // args.patch_dim
    num_patches_per_dim_w = args.img_w // args.patch_dim
    num_patches = num_patches_per_dim_h * num_patches_per_dim_w
    hidden_size = args.hidden_size

    key = prefix + "weight"

    assert key in state_dict
    if key in state_dict:
        input_param = state_dict[key]

        input_seq_len = input_param.shape[0]
        assert (isPerfectSquare(input_seq_len) or isPerfectSquare(input_seq_len - CLASS_TOKEN_LENGTH))
        input_has_class_token = not isPerfectSquare(input_seq_len)
        num_tok_input = input_seq_len - CLASS_TOKEN_LENGTH if input_has_class_token else input_seq_len
        num_tok_output = num_patches
        output_has_class_token = args.class_token_present

        # update input_param and load it to state_dict[key]
        if input_has_class_token:
            input_param_tok = input_param[:CLASS_TOKEN_LENGTH, :]
            input_param_grid = input_param[CLASS_TOKEN_LENGTH:, :]
        else:
            input_param_tok = torch.zeros(CLASS_TOKEN_LENGTH, hidden_size)
            input_param_grid = input_param

        assert input_param.shape[1] == hidden_size

        if num_tok_input != num_tok_output:
            gs_input = int(math.sqrt(num_tok_input))
            gs_new = (num_patches_per_dim_h, num_patches_per_dim_w)

            input_param_grid = input_param_grid.transpose(0, 1).contiguous()
            input_param_grid = input_param_grid.reshape(
                (1, -1, gs_input, gs_input)
            )
            input_param_grid = input_param_grid.float()
            scale_factor = (gs_new[0] / gs_input, gs_new[1] / gs_input)

            input_param_grid = F.interpolate(
                input_param_grid, scale_factor=scale_factor, mode="bilinear"
            )

            input_param_grid = input_param_grid.half()
            input_param_grid = input_param_grid.reshape((-1, num_tok_output))
            input_param_grid = input_param_grid.transpose(0, 1).contiguous()

            assert input_param_grid.shape[1] == hidden_size

            input_param = input_param_grid
            assert (
                input_param.shape[0] == num_tok_output
                and input_param.shape[1] == hidden_size
            )

        if output_has_class_token:
            input_param = torch.cat((input_param_tok, input_param), dim=0)

        state_dict[key] = input_param
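# Editorial illustration, not part of the committed file: the core of the hook above is a
# bilinear resize of the learned position-embedding grid when the checkpoint's patch grid
# differs from the current one. Here a 14x14 grid is grown to 16x16 for a 64-dim embedding.
def _interpolate_pos_embed_demo(hidden_size=64, gs_input=14, gs_new=(16, 16)):
    pos_embed = torch.randn(gs_input * gs_input, hidden_size)            # [tokens, hidden]
    grid = pos_embed.transpose(0, 1).reshape(1, -1, gs_input, gs_input)  # [1, hidden, 14, 14]
    grid = F.interpolate(grid, size=gs_new, mode="bilinear")             # [1, hidden, 16, 16]
    return grid.reshape(hidden_size, -1).transpose(0, 1)                 # [256, hidden]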
class VitBackbone(MegatronModule):
    """Vision Transformer Model."""

    def __init__(self,
                 config,
                 pre_process=True,
                 post_process=True,
                 class_token=True,
                 single_token_output=False,
                 post_layer_norm=True,
                 drop_path_rate=0.0):
        super(VitBackbone, self).__init__(share_embeddings_and_output_weights=False)
        args = get_args()

        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy

        self.pre_process = pre_process
        self.post_process = post_process
        self.class_token = class_token
        self.post_layer_norm = post_layer_norm
        self.hidden_size = args.hidden_size
        self.patch_dim = args.patch_dim
        self.img_h = args.img_h
        self.img_w = args.img_w
        self.micro_batch_size = args.micro_batch_size
        self.single_token_output = single_token_output
        self.drop_path_rate = drop_path_rate

        assert self.img_h % self.patch_dim == 0
        assert self.img_w % self.patch_dim == 0
        self.num_patches_per_dim_h = self.img_h // self.patch_dim
        self.num_patches_per_dim_w = self.img_w // self.patch_dim
        self.num_patches = self.num_patches_per_dim_h * self.num_patches_per_dim_w
        self.seq_length = self.num_patches + (CLASS_TOKEN_LENGTH if self.class_token else 0)
        self.flatten_dim = self.patch_dim * self.patch_dim * args.num_channels
        self.input_tensor = None
        self.position_ids = None

        if self.pre_process:
            # cls_token
            if self.class_token:
                self.cls_token = torch.nn.Parameter(
                    torch.randn(1, CLASS_TOKEN_LENGTH, self.hidden_size)
                )
                torch.nn.init.zeros_(self.cls_token)
            self.position_ids = torch.arange(self.seq_length).expand(1, -1).cuda()

            # Linear encoder
            self.linear_encoder = torch.nn.Linear(
                self.flatten_dim, self.hidden_size
            )

            # embedding
            self.position_embeddings = torch.nn.Embedding(
                self.seq_length, self.hidden_size
            )
            init_method_normal(args.init_method_std)(
                self.position_embeddings.weight
            )

            args.class_token_present = self.class_token
            self.position_embeddings._register_load_state_dict_pre_hook(
                twod_interpolate_position_embeddings_hook
            )

            self.embedding_dropout = torch.nn.Dropout(args.hidden_dropout)

        # Transformer
        self.transformer = ParallelTransformer(
            config,
            pre_process=self.pre_process,
            post_process=self.post_process,
            post_layer_norm=self.post_layer_norm,
            drop_path_rate=self.drop_path_rate
        )

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.transformer.set_input_tensor(input_tensor)

    def forward(self, input):
        if self.pre_process:
            rearranged_input = einops.rearrange(
                input,
                "b c (h p1) (w p2) -> b (h w) (p1 p2 c)",
                p1=self.patch_dim,
                p2=self.patch_dim,
            )

            assert rearranged_input.dtype == torch.half
            encoder_output = self.linear_encoder(rearranged_input)

            concatenated_tokens = encoder_output
            if self.class_token:
                cls_tokens = self.cls_token.expand(encoder_output.shape[0], -1, -1)
                concatenated_tokens = torch.cat((cls_tokens, encoder_output), dim=1)

            token_embeddings = concatenated_tokens + \
                self.position_embeddings(self.position_ids[:, :concatenated_tokens.shape[1]])
            # [b, s, h] => [s, b, h]
            token_embeddings = token_embeddings.transpose(0, 1).contiguous()
            hidden_states = self.embedding_dropout(token_embeddings)
        else:
            hidden_states = input

        hidden_states = self.transformer(hidden_states, None)

        if self.post_process:
            # [s b h] => [b s h]
            if self.single_token_output:
                hidden_states = hidden_states[0]
            else:
                hidden_states = hidden_states.transpose(0, 1).contiguous()

        return hidden_states
megatron/model/yuan_hf_model.py (new file, mode 100644)
# coding=utf-8
# Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved.
#
# This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX
# and OPT implementations in this library. It has been modified from its
# original forms to accommodate minor architectural differences compared
# to GPT-NeoX and OPT used by the Meta AI team that trained the model.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" PyTorch LLaMA model."""
import math
from typing import List, Optional, Tuple, Union
import torch.nn.functional as F
import torch
import torch.utils.checkpoint
from torch import nn
from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss

from transformers.activations import ACT2FN
from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
from transformers.modeling_utils import PreTrainedModel
from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
from transformers.models.llama.configuration_llama import LlamaConfig
from einops import rearrange
from flash_attn import flash_attn_varlen_func as flash_attn_unpadded_func
from flash_attn import flash_attn_func
import copy

# from flash_attn.flash_attn_interface import flash_attn_unpadded_func

logger = logging.get_logger(__name__)

_CONFIG_FOR_DOC = "LlamaConfig"
class LocalizedFiltering(torch.nn.Module):
    """
    Mega's Exponential Moving Average layer, largely left unmodified from the original repo with the exception of
    variable names and moving away from the stateful representation of incremental decoding state. See
    "https://arxiv.org/abs/2209.10655" for more details.
    """

    def __init__(self, hidden_size):
        super().__init__()

        self.embed_dim = hidden_size
        self.lf_conv2d_group = 1
        self.lf_conv2d_num_pad = 1

        self.conv1 = torch.nn.Conv2d(self.embed_dim, self.embed_dim // 2, (2, 1), stride=(1, 1),
                                     padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
        self.conv2 = torch.nn.Conv2d(self.embed_dim // 2, self.embed_dim, (2, 1), stride=(1, 1),
                                     padding=(self.lf_conv2d_num_pad, 0), groups=self.lf_conv2d_group)
        self.output_layernorm = LlamaRMSNorm(self.embed_dim)

        # renamed delta (damping_factor) and alpha (decay_factor) to be more descriptive of what the parameters are doing
        ##self.damping_factor = torch.nn.Parameter(torch.Tensor(1, 1, self.embed_dim))
        ##self.decay_factor = torch.nn.Parameter(torch.Tensor(1, 1, self.embed_dim))

    def _train_forward(self, inputs):
        torch.backends.cudnn.deterministic = True
        torch.backends.cudnn.benchmark = False
        inputs = inputs.transpose(0, 1)
        seq_len, bsz, embed_dim = inputs.size()
        if embed_dim != self.embed_dim:
            raise ValueError(
                f"Unexpected embedding dimension received: input is {embed_dim}, model expects {self.embed_dim}"
            )
        residual = inputs

        # (sequence_length x batch_size x hidden_size) -> (batch_size x sequence_length x hidden_size)
        #inputs = inputs.permute(1, 0, 2)
        inputs = inputs.view(seq_len, 1, bsz, embed_dim).permute(2, 3, 0, 1)
        output1 = self.conv1(inputs)
        output1 = output1[:, :, :seq_len, :]

        output2 = self.conv2(output1)
        output2 = output2[:, :, :seq_len, :].permute(2, 3, 0, 1).contiguous()
        output2 = output2.view(seq_len, bsz, embed_dim)
        assert output2.shape == residual.shape

        lf_output = self.output_layernorm(output2 + residual)
        lf_output = lf_output.transpose(0, 1)
        return lf_output

    def _inference_forward(self, inputs, before_hidden_states):
        if before_hidden_states is None:
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            inputs = inputs.transpose(0, 1)
            seq_len, bsz, embed_dim = inputs.size()
            if embed_dim != self.embed_dim:
                raise ValueError(
                    f"Unexpected embedding dimension received: input is {embed_dim}, model expects {self.embed_dim}"
                )
            residual = inputs

            # (sequence_length x batch_size x hidden_size) -> (batch_size x sequence_length x hidden_size)
            #inputs = inputs.permute(1, 0, 2)
            inputs = inputs.view(seq_len, 1, bsz, embed_dim).permute(2, 3, 0, 1)
            output1 = self.conv1(inputs)
            output1 = output1[:, :, :seq_len, :]

            output2 = self.conv2(output1)
            output2 = output2[:, :, :seq_len, :].permute(2, 3, 0, 1).contiguous()
            output2 = output2.view(seq_len, bsz, embed_dim)
            assert output2.shape == residual.shape

            lf_output = self.output_layernorm(output2 + residual)
            lf_output = lf_output.transpose(0, 1)
            return lf_output
        else:
            inputs = inputs.transpose(0, 1)
            before_hidden_states = before_hidden_states.transpose(0, 1)
            residual = inputs

            seq_len, bsz, embed_dim = inputs.size()
            seq_len_before, _, _ = before_hidden_states.size()

            assert seq_len == 1 and seq_len_before == 2

            inputs = torch.cat((before_hidden_states, inputs), dim=0)
            inputs = inputs.view(3, 1, bsz, embed_dim).permute(2, 3, 0, 1)

            output1 = self.conv1(inputs)
            output2 = self.conv2(output1[:, :, 1:-1, :])
            output2 = output2[:, :, 1:-1, :]
            output2 = output2.view(1, bsz, embed_dim)
            assert output2.shape == residual.shape

            lf_output = self.output_layernorm(output2 + residual)
            lf_output = lf_output.transpose(0, 1)
            return lf_output

    def forward(
        self,
        inputs,
        before_hidden_states
    ) -> torch.Tensor:
        assert self.lf_conv2d_num_pad == 1
        if self.training:
            lf_output = self._train_forward(inputs)
        else:
            lf_output = self._inference_forward(inputs, before_hidden_states)

        return lf_output
# Copied from transformers.models.bart.modeling_bart._make_causal_mask
def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, device: torch.device, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min, device=device), device=device)
    mask_cond = torch.arange(mask.size(-1), device=device)
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat([torch.zeros(tgt_len, past_key_values_length, dtype=dtype, device=device), mask], dim=-1)
    return mask[None, None, :, :].expand(bsz, 1, tgt_len, tgt_len + past_key_values_length)
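# Editorial illustration, not part of the committed file: _make_causal_mask keeps zeros on
# and below the diagonal and dtype-min above it, and prepends zero columns for cached past
# tokens, e.g. tgt_len=3 with one past token yields a [1, 1, 3, 4] mask.
def _causal_mask_demo():
    mask = _make_causal_mask(torch.Size([1, 3]), torch.float32, torch.device("cpu"),
                             past_key_values_length=1)
    # mask[0, 0] ==
    # [[0, 0, min, min],
    #  [0, 0, 0,   min],
    #  [0, 0, 0,   0  ]]   where min == torch.finfo(torch.float32).min
    return mask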
# Copied from transformers.models.bart.modeling_bart._expand_mask
def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len, src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(dtype).min)
class LlamaRMSNorm(nn.Module):
    def __init__(self, hidden_size, eps=1e-6):
        """
        LlamaRMSNorm is equivalent to T5LayerNorm
        """
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)

        # convert into half-precision if necessary
        if self.weight.dtype in [torch.float16, torch.bfloat16]:
            hidden_states = hidden_states.to(self.weight.dtype)

        return self.weight * hidden_states

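# Illustrative helper, not part of the original file: a minimal sketch checking that
# LlamaRMSNorm above matches a hand-written RMS normalization. The tensor sizes and
# tolerance are arbitrary assumptions chosen for the example.
def _rmsnorm_reference_check(hidden_size: int = 8) -> bool:
    x = torch.randn(2, 3, hidden_size)
    norm = LlamaRMSNorm(hidden_size, eps=1e-6)
    # Hand-written RMS norm: scale each vector by the reciprocal of its root-mean-square.
    expected = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + 1e-6)
    return torch.allclose(norm(x), expected, atol=1e-5)
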
class LlamaRotaryEmbedding(torch.nn.Module):
    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)

    def forward(self, x, seq_len=None):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
        if seq_len > self.max_seq_len_cached:
            self.max_seq_len_cached = seq_len
            t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Different from paper, but it uses a different permutation in order to obtain the same calculation
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
            self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
        return (
            self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
            self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
        )

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., : x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2 :]
    return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed

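# Illustrative helper, not part of the original file: a small sketch of how
# LlamaRotaryEmbedding and apply_rotary_pos_emb fit together. The batch, head, and
# sequence sizes are arbitrary assumptions; the norm check holds because RoPE rotates
# pairs of dimensions, which leaves each head vector's length unchanged.
def _rotary_embedding_shape_check(bsz: int = 2, heads: int = 4, seq: int = 5, head_dim: int = 8):
    q = torch.randn(bsz, heads, seq, head_dim)
    k = torch.randn(bsz, heads, seq, head_dim)
    rope = LlamaRotaryEmbedding(head_dim, max_position_embeddings=16)
    cos, sin = rope(q, seq_len=seq)  # each of shape [1, 1, seq, head_dim]
    position_ids = torch.arange(seq).unsqueeze(0).expand(bsz, seq)
    q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
    assert q_rot.shape == q.shape and k_rot.shape == k.shape
    assert torch.allclose(q_rot.norm(dim=-1), q.norm(dim=-1), atol=1e-5)
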
class LlamaMLP(nn.Module):
    def __init__(
        self,
        hidden_size: int,
        intermediate_size: int,
        hidden_act: str,
    ):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
        self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
        self.act_fn = ACT2FN[hidden_act]

    def forward(self, x):
        return self.down_proj(self.gate_proj(x) * self.act_fn(self.up_proj(x)))

class LlamaAttention(nn.Module):
    """Multi-headed attention from 'Attention Is All You Need' paper"""

    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.causal_mask = config.causal_mask
        self.softmax_scale = 1.0 / math.sqrt(self.head_dim)
        self.use_flash_attention = config.use_flash_attention
        try:
            self.use_shareqk = config.use_shareqk
        except Exception as e:
            self.use_shareqk = False
        self.dropout = 0.0
        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
        self.rotary_emb = LlamaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)

        if self.use_shareqk:
            self.qk_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
            self.qk_weight = nn.Parameter(torch.Tensor(2, self.hidden_size))
            self.qk_bias = nn.Parameter(torch.Tensor(2, self.hidden_size))
        else:
            self.lf_gate = LocalizedFiltering(self.hidden_size)
            self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
            self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)

    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        bsz, q_len, _ = hidden_states.size()
        before_hidden_states = None
        is_first_step = False
        if use_cache:
            if past_key_value is None:
                inference_hidden_states_memory = torch.empty(
                    bsz, 2, hidden_states.shape[2], dtype=hidden_states.dtype, device=torch.cuda.current_device()
                )
                is_first_step = True
            else:
                before_hidden_states = past_key_value[2]

        if use_cache:
            if is_first_step:
                if q_len >= 2:
                    inference_hidden_states_memory = hidden_states[:, -2:, :]
                else:
                    inference_hidden_states_memory[:, :, :] = 0
                    inference_hidden_states_memory[:, -1:, :] = hidden_states[:, -1:, :]
            else:
                hidden_states_tmp = before_hidden_states[:, -1:, :]
                inference_hidden_states_memory = copy.deepcopy(torch.cat((hidden_states_tmp, hidden_states), dim=1))
                # inference_hidden_states_memory[:, :1, :] = hidden_states_tmp
                # inference_hidden_states_memory[:, 1:, :] = hidden_states

        value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        if self.use_shareqk:
            qk_states = self.qk_proj(hidden_states).view(bsz, q_len, self.num_heads * self.head_dim)
            # qk_states = F.silu(qk_states)
            query_key = qk_states.unsqueeze(2) * self.qk_weight + self.qk_bias
            query_states, key_states = torch.unbind(query_key, dim=2)

            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        else:
            hidden_states = self.lf_gate(hidden_states, before_hidden_states)
            query_states = self.q_proj(hidden_states)
            key_states = self.k_proj(hidden_states)
            qk_states = torch.cat([query_states, key_states], dim=-1)
            qk_states = qk_states.view(bsz, q_len, self.num_heads, int(qk_states.shape[-1] // self.num_heads))
            (query_states, key_states) = torch.chunk(qk_states, 2, dim=-1)
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if past_key_value is not None:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
        # [bsz, nh, t, hd]

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states, inference_hidden_states_memory) if use_cache else None

        if self.use_flash_attention:
            attn_weights = None
            query_states = query_states.transpose(1, 2)
            key_states = key_states.transpose(1, 2)
            value_states = value_states.transpose(1, 2)

            batch_size, seqlen_q = query_states.shape[0], query_states.shape[1]
            seqlen_k = key_states.shape[1]

            q, k, v = [rearrange(x, 'b s ... -> (b s) ...') for x in [query_states, key_states, value_states]]

            cu_seqlens_q = torch.arange(0, (batch_size + 1) * seqlen_q, step=seqlen_q, dtype=torch.int, device=q.device)

            if self.training:
                assert seqlen_k == seqlen_q
                cu_seqlens_k = cu_seqlens_q
                is_causal = self.causal_mask
            else:
                is_causal = seqlen_q == seqlen_k
                cu_seqlens_k = torch.arange(0, (batch_size + 1) * seqlen_k, step=seqlen_k, dtype=torch.int, device=q.device)
                self.dropout = 0

            output = flash_attn_unpadded_func(
                q, k, v, cu_seqlens_q, cu_seqlens_k, seqlen_q, seqlen_k, self.dropout, causal=is_causal
            )

            attn_output = rearrange(output, '(b s) ... -> b s ...', b=batch_size)
        else:
            attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)

            if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                    f" {attn_weights.size()}"
                )
            if attention_mask is not None:
                if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                    raise ValueError(
                        f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                    )
                attn_weights = attn_weights + attention_mask
                attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))

            # upcast attention to fp32
            attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
            attn_output = torch.matmul(attn_weights, value_states)

            if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
                raise ValueError(
                    f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                    f" {attn_output.size()}"
                )
            attn_output = attn_output.transpose(1, 2)

        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value

class LlamaDecoderLayer(nn.Module):
    def __init__(self, config: LlamaConfig):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = LlamaAttention(config=config)
        self.mlp = LlamaMLP(
            hidden_size=self.hidden_size,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
    ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
        """
        Args:
            hidden_states (`torch.FloatTensor`): input to the layer of shape `(batch, seq_len, embed_dim)`
            attention_mask (`torch.FloatTensor`, *optional*): attention mask of size
                `(batch, 1, tgt_len, src_len)` where padding elements are indicated by very large negative values.
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention layers. See `attentions` under
                returned tensors for more detail.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding
                (see `past_key_values`).
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached past key and value projection states
        """

        residual = hidden_states

        hidden_states = self.input_layernorm(hidden_states)

        # Self Attention
        hidden_states, self_attn_weights, present_key_value = self.self_attn(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_value=past_key_value,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        hidden_states = residual + hidden_states

        # Fully Connected
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states

        outputs = (hidden_states,)

        if output_attentions:
            outputs += (self_attn_weights,)

        if use_cache:
            outputs += (present_key_value,)

        return outputs

LLAMA_START_DOCSTRING = r"""
    This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
    library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
    etc.)

    This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
    Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
    and behavior.

    Parameters:
        config ([`LlamaConfig`]):
            Model configuration class with all the parameters of the model. Initializing with a config file does not
            load the weights associated with the model, only the configuration. Check out the
            [`~PreTrainedModel.from_pretrained`] method to load the model weights.
"""


@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaPreTrainedModel(PreTrainedModel):
    config_class = LlamaConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _no_split_modules = ["LlamaDecoderLayer"]
    _skip_keys_device_placement = "past_key_values"
    _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]

    def _init_weights(self, module):
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def _set_gradient_checkpointing(self, module, value=False):
        if isinstance(module, LlamaModel):
            module.gradient_checkpointing = value

LLAMA_INPUTS_DOCSTRING = r"""
    Args:
        input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
            Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
            it.

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            [What are input IDs?](../glossary#input-ids)
        attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
            Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:

            - 1 for tokens that are **not masked**,
            - 0 for tokens that are **masked**.

            [What are attention masks?](../glossary#attention-mask)

            Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
            [`PreTrainedTokenizer.__call__`] for details.

            If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
            `past_key_values`).

            If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
            and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
            information on the default strategy.

            - 1 indicates the head is **not masked**,
            - 0 indicates the head is **masked**.
        position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
            Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
            config.n_positions - 1]`.

            [What are position IDs?](../glossary#position-ids)
        past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
            Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
            `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
            `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.

            Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
            blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.

            If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
            don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
            `decoder_input_ids` of shape `(batch_size, sequence_length)`.
        inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
            is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
            model's internal embedding lookup matrix.
        use_cache (`bool`, *optional*):
            If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
            `past_key_values`).
        output_attentions (`bool`, *optional*):
            Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
            tensors for more detail.
        output_hidden_states (`bool`, *optional*):
            Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
            more detail.
        return_dict (`bool`, *optional*):
            Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
"""

@add_start_docstrings(
    "The bare LLaMA Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class LlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`LlamaDecoderLayer`]

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        # TODO: control it by config
        self.eod_token = config.eod_token
        self.reset_attention_mask = config.reset_attention_mask
        self.reset_position_ids = config.reset_position_ids
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([LlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                device=inputs_embeds.device,
                past_key_values_length=past_key_values_length,
            )

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]).to(
                inputs_embeds.device
            )
            combined_attention_mask = (
                expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def _prepare_decoder_attention_mask_training(
        self, input_id, inputs_embeds, eod_token, reset_mask_flag, reset_attention_mask=True, reset_position_ids=True
    ):
        micro_batch_size, seq_length = input_id.size()

        attention_mask = torch.tril(
            torch.ones((micro_batch_size, seq_length, seq_length), device=inputs_embeds.device)
        ).view(micro_batch_size, 1, seq_length, seq_length)

        position_ids = torch.arange(seq_length, dtype=torch.long, device=inputs_embeds.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_id)

        if reset_position_ids:
            position_ids = position_ids.clone()

        if reset_position_ids or reset_attention_mask:
            # Loop through the batches:
            for b in range(micro_batch_size):

                # Find indices where EOD token is.
                eod_index = position_ids[b, input_id[b] == eod_token]
                # Detach indices from positions if going to modify positions.
                if reset_position_ids:
                    eod_index = eod_index.clone()

                # Loop through EOD indices:
                prev_index = 0
                for j in range(eod_index.size()[0]):
                    i = eod_index[j]
                    # Mask attention loss.
                    if reset_attention_mask:
                        attention_mask[b, 0, (i + 1):, :(i + 1)] = 0
                    # Reset positions.
                    if reset_position_ids:
                        position_ids[b, (i + 1):] -= (i + 1 - prev_index)
                        prev_index = i + 1

        inverted_mask = 1 - attention_mask
        output_attn_mask = inverted_mask.masked_fill(inverted_mask.to(torch.bool), torch.finfo(inputs_embeds.dtype).min)
        if reset_mask_flag:
            output_attn_mask = output_attn_mask[:, :, -1:, :]
        return output_attn_mask, position_ids

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        input_ids1 = copy.deepcopy(input_ids)
        reset_mask_flag = False
        if past_key_values:
            input_ids = input_ids[:, -1:]
            if use_cache:
                reset_mask_flag = True
        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")

        seq_length_with_past = seq_length
        past_key_values_length = 0

        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length

        if position_ids is None:
            device = input_ids.device if input_ids is not None else inputs_embeds.device
            position_ids = torch.arange(
                past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
            )
            position_ids = position_ids.unsqueeze(0).view(-1, seq_length)
        else:
            position_ids = position_ids.view(-1, seq_length).long()

        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)
        if self.training or self.reset_position_ids:
            attention_mask, _ = self._prepare_decoder_attention_mask_training(
                input_ids1, inputs_embeds, self.eod_token, reset_mask_flag, self.reset_attention_mask, self.reset_position_ids
            )
        else:
            if attention_mask is None:
                attention_mask = torch.ones(
                    (batch_size, seq_length_with_past), dtype=torch.bool, device=inputs_embeds.device
                )
            attention_mask = self._prepare_decoder_attention_mask(
                attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
            )

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = past_key_values[idx] if past_key_values is not None else None

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # None for past_key_value
                        return module(*inputs, output_attentions, None)

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                    position_ids,
                    None,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

class YuanForCausalLM(LlamaPreTrainedModel):
    def __init__(self, config):
        super().__init__(config)
        self.eod_token = config.eod_token
        self.sep_token = config.sep_token
        self.use_loss_mask = config.use_loss_mask
        self.model = LlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def get_loss_mask(self, input_ids, labels, eod_token, sep_token):
        micro_batch_size, seq_length = input_ids.size()
        loss_mask = torch.ones(input_ids.size(), dtype=torch.float, device=input_ids.device)

        position_ids = torch.arange(seq_length, dtype=torch.long, device=input_ids.device)
        position_ids = position_ids.unsqueeze(0).expand_as(input_ids)

        """modify loss_mask to only calculate the loss of the answer (separated with [SEP])"""
        for b in range(micro_batch_size):
            eod_indexs = position_ids[b, input_ids[b] == eod_token]
            sep_indexs = position_ids[b, input_ids[b] == sep_token]

            if len(eod_indexs) == 0 or len(sep_indexs) == 0:
                loss_mask[b] = 1.0
            else:
                if eod_indexs[0] > sep_indexs[0]:
                    loss_mask[b, 0:sep_indexs[0]] = 0

                    if len(eod_indexs) == len(sep_indexs):
                        for ii, eod_index in enumerate(eod_indexs):
                            start_index = eod_index
                            if ii == (len(sep_indexs) - 1):
                                stop_index = seq_length
                            else:
                                stop_index = sep_indexs[ii + 1]
                            loss_mask[b, start_index:stop_index] = 0.0
                    else:
                        if len(eod_indexs) > len(sep_indexs):
                            loss_mask[b, :] = 1.0
                        else:
                            for ii, eod_index in enumerate(eod_indexs):
                                start_index = eod_index
                                stop_index = sep_indexs[ii + 1]
                                loss_mask[b, start_index:stop_index] = 0.0

                elif eod_indexs[0] < sep_indexs[0]:
                    if len(eod_indexs) == len(sep_indexs):
                        for ii, eod_index in enumerate(eod_indexs):
                            start_index = eod_index
                            stop_index = sep_indexs[ii]
                            loss_mask[b, start_index:stop_index] = 0.0
                    else:
                        if len(eod_indexs) < len(sep_indexs):
                            loss_mask[b, :] = 1.0
                        else:
                            for ii, eod_index in enumerate(eod_indexs):
                                start_index = eod_index
                                if ii >= len(sep_indexs):
                                    stop_index = seq_length
                                else:
                                    stop_index = sep_indexs[ii]
                                loss_mask[b, start_index:stop_index] = 0.0

        loss_mask[input_ids == eod_token] = 1.0
        return loss_mask

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        r"""
        Args:
            labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
                Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
                config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
                (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.

        Returns:

        Example:

        ```python
        >>> from transformers import AutoTokenizer, LlamaForCausalLM

        >>> model = LlamaForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
        >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)

        >>> prompt = "Hey, are you conscious? Can you talk to me?"
        >>> inputs = tokenizer(prompt, return_tensors="pt")

        >>> # Generate
        >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
        >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
        "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
        ```"""

        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
        # input_ids = input_ids.cuda()
        # labels = labels.cuda()
        # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs[0]
        logits = self.lm_head(hidden_states)
        loss = None
        if labels is not None:
            if self.use_loss_mask:
                loss_mask = self.get_loss_mask(input_ids, labels, self.eod_token, self.sep_token)
            # Shift so that tokens < n predict n
            shift_logits = logits[..., :, :].contiguous()
            shift_labels = labels[..., :].contiguous()
            # Flatten the tokens
            if self.use_loss_mask:
                loss_fct = CrossEntropyLoss(reduction='none')
                shift_logits = shift_logits.view(-1, self.config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)
                loss = torch.sum(loss * loss_mask) / loss_mask.sum()
            else:
                loss_fct = CrossEntropyLoss()
                shift_logits = shift_logits.view(-1, self.config.vocab_size)
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs):
        # if past_key_values:
        #     input_ids = input_ids[:, -1:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -1].unsqueeze(-1)

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
        return reordered_past

@add_start_docstrings(
    """
    The LLaMa Model transformer with a sequence classification head on top (linear layer).

    [`LlamaForSequenceClassification`] uses the last token in order to do the classification, as other causal models
    (e.g. GPT-2) do.

    Since it does classification on the last token, it requires to know the position of the last token. If a
    `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
    no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
    padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
    each row of the batch).
    """,
    LLAMA_START_DOCSTRING,
)
class LlamaForSequenceClassification(LlamaPreTrainedModel):
    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.num_labels = config.num_labels
        self.model = LlamaModel(config)
        self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    @add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING)
    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
    ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
        r"""
        labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
            Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
            config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
            `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        transformer_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        logits = self.score(hidden_states)

        if input_ids is not None:
            batch_size = input_ids.shape[0]
        else:
            batch_size = inputs_embeds.shape[0]

        if self.config.pad_token_id is None and batch_size != 1:
            raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
        if self.config.pad_token_id is None:
            sequence_lengths = -1
        else:
            if input_ids is not None:
                sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
            else:
                sequence_lengths = -1

        pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]

        loss = None
        if labels is not None:
            labels = labels.to(logits.device)
            if self.config.problem_type is None:
                if self.num_labels == 1:
                    self.config.problem_type = "regression"
                elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
                    self.config.problem_type = "single_label_classification"
                else:
                    self.config.problem_type = "multi_label_classification"

            if self.config.problem_type == "regression":
                loss_fct = MSELoss()
                if self.num_labels == 1:
                    loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
                else:
                    loss = loss_fct(pooled_logits, labels)
            elif self.config.problem_type == "single_label_classification":
                loss_fct = CrossEntropyLoss()
                loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
            elif self.config.problem_type == "multi_label_classification":
                loss_fct = BCEWithLogitsLoss()
                loss = loss_fct(pooled_logits, labels)
        if not return_dict:
            output = (pooled_logits,) + transformer_outputs[1:]
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutputWithPast(
            loss=loss,
            logits=pooled_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

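# Illustrative usage sketch, not part of the original file. It only shows which extra
# config fields this file expects on LlamaConfig (eod_token, sep_token, use_loss_mask,
# use_flash_attention, causal_mask, reset_attention_mask, reset_position_ids). The
# values below are assumptions for a tiny CPU smoke test, not the released Yuan2.0
# configuration, and the forward pass assumes the LocalizedFiltering module defined
# earlier in this file behaves as in the released Yuan2.0 implementation.
def _tiny_yuan_smoke_test():
    config = LlamaConfig(vocab_size=1024, hidden_size=64, intermediate_size=128,
                         num_hidden_layers=2, num_attention_heads=4)
    # Fields consumed by this file but absent from the stock LlamaConfig.
    config.eod_token = 0
    config.sep_token = 1
    config.use_loss_mask = False
    config.use_flash_attention = False
    config.causal_mask = True
    config.reset_attention_mask = False
    config.reset_position_ids = False
    model = YuanForCausalLM(config).eval()
    input_ids = torch.randint(2, config.vocab_size, (1, 8))
    with torch.no_grad():
        out = model(input_ids=input_ids, use_cache=False)
    assert out.logits.shape == (1, 8, config.vocab_size)
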
megatron/model/yuan_model.py 0 → 100644 View file @ d3dd8642
# Copyright (c) 2023, NVIDIA CORPORATION. All rights reserved.

"""Yuan model."""

import torch

from megatron import get_args
from megatron.core import tensor_parallel
from .module import MegatronModule

from .enums import AttnMaskType
from .language_model import parallel_lm_logits
from .language_model import get_language_model


def post_language_model_processing(lm_output, labels, logit_weights,
                                   parallel_output,
                                   fp16_lm_cross_entropy):

    # Output. Format [s b h]
    output = parallel_lm_logits(
        lm_output,
        logit_weights,
        parallel_output)

    if labels is None:
        # [s b h] => [b s h]
        return output.transpose(0, 1).contiguous()
    else:
        # [b s] => [s b]
        labels = labels.transpose(0, 1).contiguous()
        if fp16_lm_cross_entropy:
            assert output.dtype == torch.half
            loss = tensor_parallel.vocab_parallel_cross_entropy(output, labels)
        else:
            loss = tensor_parallel.vocab_parallel_cross_entropy(output.float(), labels)

        # [s b] => [b, s]
        loss = loss.transpose(0, 1).contiguous()
        return loss


class YuanModel(MegatronModule):
    """Yuan language model."""

    def __init__(self,
                 config,
                 num_tokentypes=0,
                 parallel_output=True,
                 pre_process=True,
                 post_process=True):
        args = get_args()
        super().__init__(config=config,
                         share_embeddings_and_output_weights=not args.untie_embeddings_and_output_weights)

        self.parallel_output = parallel_output
        self.pre_process = pre_process
        self.post_process = post_process
        self.fp16_lm_cross_entropy = args.fp16_lm_cross_entropy
        self.untie_embeddings_and_output_weights = args.untie_embeddings_and_output_weights

        self.language_model, self._language_model_key = get_language_model(
            config=config,
            num_tokentypes=num_tokentypes,
            add_pooler=False,
            encoder_attn_mask_type=AttnMaskType.causal,
            pre_process=self.pre_process,
            post_process=self.post_process)

        if not args.untie_embeddings_and_output_weights:
            self.initialize_word_embeddings()

    def set_input_tensor(self, input_tensor):
        """See megatron.model.transformer.set_input_tensor()"""
        self.language_model.set_input_tensor(input_tensor)

    def forward(self, input_ids, position_ids, attention_mask,
                retriever_input_ids=None,
                retriever_position_ids=None,
                retriever_attn_mask=None,
                labels=None, tokentype_ids=None, inference_params=None):

        lm_output = self.language_model(
            input_ids,
            position_ids,
            attention_mask,
            retriever_input_ids=retriever_input_ids,
            retriever_position_ids=retriever_position_ids,
            retriever_attn_mask=retriever_attn_mask,
            inference_params=inference_params)

        if self.post_process:
            return post_language_model_processing(
                lm_output, labels,
                self.language_model.output_layer.weight if self.untie_embeddings_and_output_weights
                else self.shared_embedding_or_output_weight(),
                self.parallel_output,
                self.fp16_lm_cross_entropy)
        else:
            return lm_output

    def state_dict_for_save_checkpoint(self, prefix='', keep_vars=False):

        state_dict_ = {}
        state_dict_[self._language_model_key] \
            = self.language_model.state_dict_for_save_checkpoint(
                prefix=prefix, keep_vars=keep_vars)
        # Save word_embeddings.
        if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights:
            state_dict_[self._word_embeddings_for_head_key] \
                = self.word_embeddings.state_dict(prefix=prefix,
                                                  keep_vars=keep_vars)
        return state_dict_

    def load_state_dict(self, state_dict, strict=True):
        """Customized load."""

        # Load word_embeddings.
        if self.post_process and not self.pre_process and not self.untie_embeddings_and_output_weights:
            self.word_embeddings.load_state_dict(
                state_dict[self._word_embeddings_for_head_key], strict=strict)
        if self._language_model_key in state_dict:
            state_dict = state_dict[self._language_model_key]
        self.language_model.load_state_dict(state_dict, strict=strict)
megatron/mpu/tests/__init__.py 0 → 100644 View file @ d3dd8642
megatron/mpu/tests/commons.py 0 → 100644 View file @ d3dd8642
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import argparse
import os
import random
import numpy
import torch

import mpu


class IdentityLayer(torch.nn.Module):
    def __init__(self, size, scale=1.0):
        super(IdentityLayer, self).__init__()
        self.weight = torch.nn.Parameter(scale * torch.randn(size))

    def forward(self):
        return self.weight


def set_random_seed(seed):
    """Set random seed for reproducibility."""
    random.seed(seed)
    numpy.random.seed(seed)
    torch.manual_seed(seed)
    mpu.model_parallel_cuda_manual_seed(seed)


def initialize_distributed(backend='nccl'):
    """Initialize torch.distributed."""
    # Get local rank in case it is provided.
    parser = argparse.ArgumentParser()
    parser.add_argument('--local_rank', type=int, default=None,
                        help='local rank passed from distributed launcher')
    args = parser.parse_args()
    local_rank = args.local_rank

    # Get rank and world size.
    rank = int(os.getenv('RANK', '0'))
    world_size = int(os.getenv("WORLD_SIZE", '1'))

    print('> initializing torch.distributed with local rank: {}, '
          'rank: {}, world size: {}'.format(local_rank, rank, world_size))

    # Set the device id.
    device = rank % torch.cuda.device_count()
    if local_rank is not None:
        device = local_rank
    torch.cuda.set_device(device)

    # Call the init process.
    init_method = 'tcp://'
    master_ip = os.getenv('MASTER_ADDR', 'localhost')
    master_port = os.getenv('MASTER_PORT', '6000')
    init_method += master_ip + ':' + master_port
    torch.distributed.init_process_group(
        backend=backend,
        world_size=world_size,
        rank=rank,
        init_method=init_method)


def print_separator(message):
    torch.distributed.barrier()
    filler_len = (78 - len(message)) // 2
    filler = '-' * filler_len
    string = '\n' + filler + ' {} '.format(message) + filler
    if torch.distributed.get_rank() == 0:
        print(string, flush=True)
    torch.distributed.barrier()
megatron/mpu/tests/test_cross_entropy.py 0 → 100644 View file @ d3dd8642
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

from commons import set_random_seed
from commons import IdentityLayer
from commons import print_separator
from commons import initialize_distributed
from mpu.cross_entropy import vocab_parallel_cross_entropy
import mpu
import torch.nn.functional as F
import torch
import random
import sys
sys.path.append("../..")


def torch_cross_entropy(batch_size, seq_length, vocab_size,
                        logits_scale, seed):
    set_random_seed(seed)
    identity = IdentityLayer((batch_size, seq_length, vocab_size),
                             scale=logits_scale).cuda()
    logits = identity()
    target = torch.cuda.LongTensor(
        size=(batch_size, seq_length)).random_(0, vocab_size)
    loss = F.cross_entropy(logits.view(-1, logits.size()[-1]),
                           target.view(-1),
                           reduction='none').view_as(target).mean()
    loss.backward()
    return loss, identity.weight.grad


def mpu_cross_entropy(batch_size, seq_length, vocab_size,
                      logits_scale, seed):
    set_random_seed(seed)
    identity = IdentityLayer((batch_size, seq_length, vocab_size),
                             scale=logits_scale).cuda()
    logits = identity()
    logits_parallel = mpu.scatter_to_tensor_model_parallel_region(logits)
    target = torch.cuda.LongTensor(
        size=(batch_size, seq_length)).random_(0, vocab_size)
    loss = vocab_parallel_cross_entropy(logits_parallel, target).mean()
    loss.backward()
    return loss, identity.weight.grad


def test_cross_entropy(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing cross entropy with model parallel size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    batch_size = 13
    seq_length = 17
    vocab_size_per_partition = 11
    logits_scale = 1000.0
    vocab_size = vocab_size_per_partition * tensor_model_parallel_size
    seed = 1234

    loss_torch, grad_torch = torch_cross_entropy(batch_size, seq_length,
                                                 vocab_size, logits_scale,
                                                 seed)
    loss_mpu, grad_mpu = mpu_cross_entropy(batch_size, seq_length,
                                           vocab_size, logits_scale,
                                           seed)

    error = loss_torch.sub_(loss_mpu).abs().max()
    print('   max error in loss on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = grad_torch.sub_(grad_mpu).abs().max()
    print('   max error in grad on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_tensor_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test cross entropy')
        test_cross_entropy(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
megatron/mpu/tests/test_data.py 0 → 100644 View file @ d3dd8642
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

from commons import print_separator
from commons import initialize_distributed
from mpu import data as data_utils
import mpu
import torch
import functools
import operator
import sys
sys.path.append("../..")


def test_broadcast_data(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing broadcast_data with model parallel size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    torch.manual_seed(1234 + mpu.get_data_parallel_rank())
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    key_size_t = {'key1': [7, 11],
                  'key2': [8, 2, 1],
                  'key3': [13],
                  'key4': [5, 1, 2],
                  'key5': [5, 12]}
    keys = list(key_size_t.keys())

    data = {}
    data_t = {}
    for key in key_size_t:
        data[key] = torch.LongTensor(size=key_size_t[key]).random_(0, 1000)
        data_t[key] = data[key].clone()
    data['keyX'] = torch.FloatTensor(size=(5, )).random_(0, 1000)
    data_t['keyX'] = data['keyX'].clone()
    if mpu.get_tensor_model_parallel_rank() != 0:
        data = None

    data_utils._check_data_types(keys, data_t, torch.int64)
    key_size, key_numel, \
        total_numel = data_utils._build_key_size_numel_dictionaries(keys, data)
    for key in keys:
        assert key_size[key] == key_size_t[key]
    total_numel_t = 0
    for key in keys:
        target_size = functools.reduce(operator.mul, key_size_t[key], 1)
        assert key_numel[key] == target_size
        total_numel_t += target_size
    assert total_numel == total_numel_t

    data_b = data_utils.broadcast_data(keys, data, torch.int64)
    for key in keys:
        tensor = data_t[key].cuda()
        assert data_b[key].sub(tensor).abs().max() == 0

    # Reset groups
    mpu.destroy_tensor_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test broadcast data')
        test_broadcast_data(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
megatron/mpu/tests/test_initialize.py 0 → 100644 View file @ d3dd8642
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

from commons import print_separator
from commons import initialize_distributed
import mpu
import torch
import sys
sys.path.append("../..")


def test_initialize_model_parallel(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing initialize_model_parallel with size {} ...'.format(
            tensor_model_parallel_size))
    tensor_model_parallel_size_ = min(tensor_model_parallel_size,
                                      torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(tensor_model_parallel_size_)
    assert mpu.model_parallel_is_initialized()

    # Checks.
    def check(group, world_size, rank):
        assert world_size == torch.distributed.get_world_size(group=group)
        assert rank == torch.distributed.get_rank(group=group)

    # Model parallel.
    world_size = tensor_model_parallel_size_
    rank = torch.distributed.get_rank() % tensor_model_parallel_size_
    assert world_size == mpu.get_tensor_model_parallel_world_size()
    assert rank == mpu.get_tensor_model_parallel_rank()
    check(mpu.get_tensor_model_parallel_group(), world_size, rank)

    # Data parallel.
    world_size = torch.distributed.get_world_size() // tensor_model_parallel_size_
    rank = torch.distributed.get_rank() // tensor_model_parallel_size
    assert world_size == mpu.get_data_parallel_world_size()
    assert rank == mpu.get_data_parallel_rank()
    check(mpu.get_data_parallel_group(), world_size, rank)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


def test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size_):

    if torch.distributed.get_rank() == 0:
        print('> testing get_tensor_model_parallel_src_rank with size {} ...'.format(
            tensor_model_parallel_size_))
    tensor_model_parallel_size = min(tensor_model_parallel_size_,
                                     torch.distributed.get_world_size())
    assert not mpu.model_parallel_is_initialized()
    mpu.initialize_model_parallel(tensor_model_parallel_size)
    assert mpu.model_parallel_is_initialized()

    # Checks
    src_rank = torch.distributed.get_rank() - mpu.get_tensor_model_parallel_rank()
    assert mpu.get_tensor_model_parallel_src_rank() == src_rank

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')


if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test initialize model parallel')
        test_initialize_model_parallel(tensor_model_parallel_size)
        print_separator('test model parallel source rank')
        test_get_tensor_model_parallel_src_rank(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
megatron/mpu/tests/test_layers.py 0 → 100644 View file @ d3dd8642
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
from
mpu
import
layers
from
commons
import
set_random_seed
from
commons
import
print_separator
from
commons
import
initialize_distributed
import
mpu
from
torch.nn.parameter
import
Parameter
import
torch.nn.init
as
init
import
torch
import
random
import
sys
sys
.
path
.
append
(
"../.."
)
def test_parallel_embedding(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing parallel embedding with model parallel size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    batch_size = 17
    seq_length = 23
    vocab_size = 48
    hidden_size = 16
    seed = 1236

    set_random_seed(123)
    input_data = torch.LongTensor(
        size=(batch_size, seq_length)).random_(0, vocab_size).cuda()
    loss_weight = torch.randn([batch_size, seq_length, hidden_size]).cuda()

    set_random_seed(seed)
    embedding_original = torch.nn.Embedding(vocab_size, hidden_size).cuda()

    output = embedding_original(input_data)
    loss_original = torch.mul(output, loss_weight).sum()
    loss_original.backward()

    set_random_seed(seed)
    embedding_parallel = layers.ParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_parallel(input_data)
    loss_parallel = torch.mul(output, loss_weight).sum()
    loss_parallel.backward()

    set_random_seed(seed)
    embedding_vocab_parallel = layers.VocabParallelEmbedding(
        vocab_size, hidden_size, init_method=init.normal_).cuda()
    output = embedding_vocab_parallel(input_data)
    loss_vocab_parallel = torch.mul(output, loss_weight).sum()
    loss_vocab_parallel.backward()

    torch.distributed.barrier()
    error = loss_parallel.sub(loss_original).abs()
    print(' error in loss (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    torch.distributed.barrier()
    error = loss_vocab_parallel.sub(loss_original).abs()
    print(' error in loss (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    weight_grad_orig = torch.split(embedding_original.weight.grad,
                                   hidden_size // tensor_model_parallel_size,
                                   1)[mpu.get_tensor_model_parallel_rank()]
    error = embedding_parallel.weight.grad.sub(weight_grad_orig).abs().max()
    print(' error in grad (parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    weight_grad_orig = torch.split(embedding_original.weight.grad,
                                   vocab_size // tensor_model_parallel_size,
                                   0)[mpu.get_tensor_model_parallel_rank()]
    error = embedding_vocab_parallel.weight.grad.sub(
        weight_grad_orig).abs().max()
    print(' error in grad (vocab parallel) on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-12, 'error: {}'.format(error)

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
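The gradient checks above encode the two sharding schemes: ParallelEmbedding holds a slice of the weight along the hidden dimension, while VocabParallelEmbedding holds a slice along the vocabulary dimension. A shape-only sketch of the assumed layout, using the sizes from the test:

# Shape-only sketch of the two embedding shards checked above (assumed layout).
vocab_size, hidden_size, tp = 48, 16, 4
parallel_emb_shard = (vocab_size, hidden_size // tp)        # split on dim 1 -> (48, 4)
vocab_parallel_emb_shard = (vocab_size // tp, hidden_size)  # split on dim 0 -> (12, 16)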
def test_initialize_affine_weight(tensor_model_parallel_size):

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing initialize_affine_weight with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size

    # ---------------
    # Column parallel
    # ---------------
    weight = torch.empty(output_size_coeff, input_size)
    set_random_seed(seed)
    layers._initialize_affine_weight(weight, output_size, input_size,
                                     output_size_coeff, 0,
                                     torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_tensor_model_parallel_rank()
    my_weight = torch.split(master_weight, output_size_coeff,
                            dim=0)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(' column parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # ------------
    # Row parallel
    # ------------
    weight = torch.empty(output_size, input_size_coeff)
    set_random_seed(seed)
    mpu.layers._initialize_affine_weight(weight, output_size, input_size,
                                         input_size_coeff, 1,
                                         torch.nn.init.normal_)
    # Target.
    set_random_seed(seed)
    master_weight = torch.empty(output_size, input_size)
    torch.nn.init.normal_(master_weight)
    rank = mpu.get_tensor_model_parallel_rank()
    my_weight = torch.split(master_weight, input_size_coeff,
                            dim=1)[rank].contiguous().clone()

    # Compare.
    error = weight.sub(my_weight).abs().max()
    torch.distributed.barrier()
    print(' row parallel max error (should be zero) on global rank '
          '{}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
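The comparison above works because _initialize_affine_weight materializes the full master weight identically on every rank (same seed) and then keeps only this rank's slice along the partition dimension: dim 0 for column parallelism, dim 1 for row parallelism. A CPU-only sketch of that slicing, with illustrative names rather than the library API:

# CPU-only sketch of the slicing verified above.
import torch

def local_shard(master_weight, per_partition, partition_dim, rank):
    return torch.split(master_weight, per_partition,
                       dim=partition_dim)[rank].contiguous().clone()

master = torch.randn(17 * 4, 13 * 4)            # output_size x input_size, tp=4
col_shard = local_shard(master, 17, 0, rank=2)  # shape (17, 52): rows for this rank
row_shard = local_shard(master, 13, 1, rank=2)  # shape (68, 13): columns for this rank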
class IdentityLayer2D(torch.nn.Module):
    def __init__(self, m, n):
        super(IdentityLayer2D, self).__init__()
        self.weight = Parameter(torch.Tensor(m, n))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        return self.weight


def test_column_parallel_linear(tensor_model_parallel_size):

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing ColumnParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = mpu.ColumnParallelLinear(
        input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_tensor_model_parallel_rank()
    my_dLdA = torch.split(dLdA, output_size_coeff,
                          dim=0)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    my_dLdb = torch.split(dLdb, output_size_coeff,
                          dim=0)[rank].contiguous().clone()
    error = my_dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
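The reference gradients computed in the test follow directly from the linear map being checked. A worked restatement, assuming the layer computes Y = X Aᵀ + b (which is what comparing against linear_layer.master_weight encodes) and the scalar loss is the elementwise product with loss_weight:

% Reference gradients used above, with X in R^{B x d_in}, A in R^{d_out x d_in},
% and L = \sum_{ij} Y_{ij} W_{ij} where W = loss_weight:
\[
\frac{\partial L}{\partial Y} = W, \qquad
\frac{\partial L}{\partial A} = \Big(\frac{\partial L}{\partial Y}\Big)^{\top} X, \qquad
\frac{\partial L}{\partial b} = \mathbf{1}^{\top}\,\frac{\partial L}{\partial Y}, \qquad
\frac{\partial L}{\partial X} = \frac{\partial L}{\partial Y}\, A .
\]

Column parallelism shards A and b along the output dimension, which is why the expected dLdA and dLdb are taken with torch.split(..., output_size_coeff, dim=0) before comparing, while dLdX is replicated on every rank and needs no split.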
def test_row_parallel_linear(tensor_model_parallel_size):

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    if torch.distributed.get_rank() == 0:
        print('> testing RowParallelLinear with model parallel '
              'size: {}'.format(tensor_model_parallel_size))
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)
    input_size_coeff = 13
    input_size = input_size_coeff * tensor_model_parallel_size
    output_size_coeff = 17
    output_size = output_size_coeff * tensor_model_parallel_size
    batch_size = 7

    # Network
    identity_layer = IdentityLayer2D(batch_size, input_size).cuda()
    linear_layer = mpu.RowParallelLinear(
        input_size, output_size, keep_master_weight_for_test=True).cuda()
    loss_weight = torch.randn([batch_size, output_size]).cuda()
    # Forward
    input_ = identity_layer()
    output = linear_layer(input_)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    # Values.
    dLdY = loss_weight
    X = identity_layer.weight
    A = linear_layer.master_weight.cuda()
    dLdA = torch.matmul(dLdY.t(), X)
    dLdb = torch.matmul(torch.ones(batch_size, 1).cuda().t(), dLdY).view(-1)
    dLdX = torch.matmul(dLdY, A)

    rank = mpu.get_tensor_model_parallel_rank()
    my_dLdA = torch.split(dLdA, input_size_coeff,
                          dim=1)[rank].contiguous().clone()
    error = my_dLdA.sub(linear_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdA on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdb.sub(linear_layer.bias.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdb on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    error = dLdX.sub(identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' error in dLdX on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
class IdentityLayer3D(torch.nn.Module):
    def __init__(self, m, n, k):
        super(IdentityLayer3D, self).__init__()
        self.weight = Parameter(torch.Tensor(m, n, k))
        torch.nn.init.xavier_normal_(self.weight)

    def forward(self):
        return self.weight


def parallel_self_attention(tensor_model_parallel_size, num_att_heads_per_partition,
                            hidden_size_per_att_head, dropout_prob, batch_size,
                            sequence_length):
    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    attention_layer = mpu.BertParallelSelfAttention(hidden_size, num_att_heads,
                                                    dropout_prob).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = attention_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_tensor_model_parallel_rank()
    mpu.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        attention_layer, identity_layer


def test_parallel_self_attention(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing ParallelSelfAttention with model parallel '
              'size: {}'.format(tensor_model_parallel_size))

    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    dropout_prob = 0.0  # has to be zero
    batch_size = 5
    sequence_length = 13

    rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
        attention_layer_1, identity_layer_1 = parallel_self_attention(
            1, num_att_heads_per_partition, hidden_size_per_att_head,
            dropout_prob, batch_size, sequence_length)

    rank, hidden_size, tensor_model_parallel_size, loss, \
        attention_layer, identity_layer = parallel_self_attention(
            tensor_model_parallel_size, num_att_heads_per_partition,
            hidden_size_per_att_head, dropout_prob, batch_size, sequence_length)
    assert hidden_size_1 == hidden_size

    error = loss_1.sub(loss).abs().max()
    torch.distributed.barrier()
    print(' loss error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6

    my_lin_grad_list = torch.split(
        attention_layer_1.query_key_value.weight.grad,
        hidden_size // tensor_model_parallel_size,
        0)[rank::tensor_model_parallel_size]
    my_lin_grad = torch.cat(my_lin_grad_list, dim=0)
    error = my_lin_grad.sub(
        attention_layer.query_key_value.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' weight gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6

    error = identity_layer_1.weight.grad.sub(
        identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' input gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-6

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')
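The strided split [rank::tensor_model_parallel_size] above reflects an assumed fused-QKV layout: the single-GPU query_key_value weight stacks the query, key and value blocks along dim 0, so a given tensor-parallel rank's local weight corresponds to every tp-th chunk of the unsharded gradient rather than one contiguous slice. An index-only sketch of that selection:

# Index sketch of the strided split above (assumed [Q; K; V] stacking along dim 0).
tp = 4
chunks = [f'{blk}{i}' for blk in 'QKV' for i in range(tp)]  # 3 * tp row chunks
rank = 2
local = chunks[rank::tp]   # ['Q2', 'K2', 'V2'] -> this rank's fused QKV shard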
def parallel_transformer(tensor_model_parallel_size, num_att_heads_per_partition,
                         hidden_size_per_att_head, batch_size, sequence_length):

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed = 12345
    set_random_seed(seed)

    num_att_heads = num_att_heads_per_partition * \
        torch.distributed.get_world_size()
    hidden_size = hidden_size_per_att_head * num_att_heads
    intermediate_size = 4 * hidden_size

    # Network
    identity_layer = IdentityLayer3D(batch_size, sequence_length,
                                     hidden_size).cuda()
    transformer_layer = mpu.BertParallelTransformerLayer(
        hidden_size, intermediate_size, num_att_heads, 0.0, 0.0,
        torch.nn.functional.relu, 1.0e-5).cuda()
    loss_weight = torch.randn([batch_size, sequence_length, hidden_size]).cuda()
    attention_mask = torch.randn([batch_size, 1, 1, sequence_length]).cuda()
    # Forward
    input_ = identity_layer()
    output = transformer_layer(input_, attention_mask)
    loss = torch.mul(output, loss_weight).sum()
    # Backward
    loss.backward()

    rank = mpu.get_tensor_model_parallel_rank()
    mpu.destroy_model_parallel()
    return rank, hidden_size, tensor_model_parallel_size, loss, \
        transformer_layer, identity_layer


def test_parallel_transformer_layer(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing ParallelTransformerLayer with model parallel '
              'size: {}'.format(tensor_model_parallel_size))

    num_att_heads_per_partition = 3
    hidden_size_per_att_head = 7
    batch_size = 5
    sequence_length = 13

    rank_1, hidden_size_1, tensor_model_parallel_size_1, loss_1, \
        transformer_layer_1, identity_layer_1 = parallel_transformer(
            1, num_att_heads_per_partition, hidden_size_per_att_head,
            batch_size, sequence_length)

    rank, hidden_size, tensor_model_parallel_size, loss, \
        transformer_layer, identity_layer = parallel_transformer(
            tensor_model_parallel_size, num_att_heads_per_partition,
            hidden_size_per_att_head, batch_size, sequence_length)

    error = loss_1.sub(loss).abs().max()
    torch.distributed.barrier()
    print(' loss error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-5, 'error: {}'.format(error)

    error = identity_layer_1.weight.grad.sub(
        identity_layer.weight.grad).abs().max()
    torch.distributed.barrier()
    print(' input gradient error on global rank {}: {}'.format(
        torch.distributed.get_rank(), error))
    assert error < 5.0e-5, 'error: {}'.format(error)

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print(' >> passed the test :-)')


if __name__ == '__main__':

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    print_separator('test initialize affine weight')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_initialize_affine_weight(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test parallel embedding')
        test_parallel_embedding(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    print_separator('test column-parallel linear')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_column_parallel_linear(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    print_separator('test row-parallel linear')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_row_parallel_linear(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    print_separator('test parallel self-attention')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_parallel_self_attention(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    print_separator('test parallel transformer')
    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        test_parallel_transformer_layer(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
megatron/mpu/tests/test_random.py
0 → 100644
View file @
d3dd8642
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

import sys
sys.path.append("../..")

from commons import print_separator
from commons import initialize_distributed
import mpu
import torch
def test_set_cuda_rng_state(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing set_rng_state with size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    size = 123
    seed = 1234
    torch.cuda.manual_seed(1234)
    tensor = torch.cuda.FloatTensor(size)

    # Get the state
    rng_state = torch.cuda.get_rng_state()
    rng_state_copy = rng_state.clone()

    # Do some stuff.
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_1 = tensor.clone()

    assert rng_state.sub(rng_state_copy).max() == 0
    assert torch.cuda.get_rng_state().sub(rng_state_copy).max() > 0

    # State should be different.
    new_rng_state = torch.cuda.get_rng_state()
    max_diff = new_rng_state.sub(rng_state).max()
    print(' max diff in rng state (should be non-zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), max_diff))
    assert max_diff > 0

    # Reset the rng state and do the same stuff.
    mpu.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    mpu.random._set_cuda_rng_state(rng_state)
    for _ in range(5):
        torch.randn(size, out=tensor)
    result_2 = tensor.clone()

    # Results should be the same
    error = result_2.sub(result_1).abs().max()
    print(' max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Input state should have remained intact.
    error = rng_state.sub(rng_state_copy).max()
    print(' max error in rng state (should be zero) on global rank {}: {}'.
          format(torch.distributed.get_rank(), error))
    assert error == 0

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
def test_cuda_rng_tracker(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing cuda rng tracker with size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    seed_1 = 1234
    seed_2 = 4321
    size = [12, 21]
    tensor = torch.cuda.FloatTensor(size)

    # Set to seed_1 and generate two tensors.
    torch.cuda.manual_seed(seed_1)
    torch.randn(size, out=tensor)
    target_11 = tensor.clone()
    torch.randn(size, out=tensor)
    target_12 = tensor.clone()

    # Set to seed_2 and generate two tensors.
    torch.cuda.manual_seed(seed_2)
    torch.randn(size, out=tensor)
    target_21 = tensor.clone()
    torch.randn(size, out=tensor)
    target_22 = tensor.clone()

    # Now if we interleave seed_1 and seed_2,
    # we should still get the same tensors
    torch.cuda.manual_seed(seed_1)
    mpu.get_cuda_rng_tracker().add('test', seed_2)

    torch.randn(size, out=tensor)
    result_11 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_21 = tensor.clone()

    torch.randn(size, out=tensor)
    result_12 = tensor.clone()

    with mpu.get_cuda_rng_tracker().fork('test'):
        torch.randn(size, out=tensor)
        result_22 = tensor.clone()

    diff = result_11.sub(result_21).abs().max()
    diff = min(diff, result_12.sub(result_22).abs().max())
    print(' max diff in generated tensors (should be non-zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), diff))
    assert diff > 1.0e-6
    error = max(result_11.sub(target_11).abs().max(),
                result_12.sub(target_12).abs().max())
    error = max(error, result_21.sub(target_21).abs().max())
    error = max(error, result_22.sub(target_22).abs().max())
    print(' max error in generated tensors (should be zero) on '
          'global rank {}: {}'.format(torch.distributed.get_rank(), error))
    assert error < 1.0e-6

    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
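What the interleaving check demonstrates is that the tracker keeps a named CUDA generator state separate from the default one: forking swaps the named state in, draws from it, then restores the default state, so each stream reproduces its un-interleaved sequence. A hedged usage sketch of the typical application, per-rank randomness inside tensor-parallel regions (names follow the test-style imports above; the wrapper function is illustrative):

# Hedged usage sketch: dropout on activations that should differ across TP
# ranks is drawn under the tracker's fork, so it neither disturbs nor is
# disturbed by the default (replicated) RNG stream.
import torch
import mpu  # same import path as in the tests above

def dropout_in_tp_region(x, p=0.1):
    with mpu.get_cuda_rng_tracker().fork():
        return torch.nn.functional.dropout(x, p=p, training=True)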
def test_model_parallel_cuda_manual_seed(tensor_model_parallel_size):

    if torch.distributed.get_rank() == 0:
        print('> testing model parallel cuda manual seed with size {} ...'.
              format(tensor_model_parallel_size))

    mpu.initialize_model_parallel(tensor_model_parallel_size)
    tensor_model_parallel_size = mpu.get_tensor_model_parallel_world_size()

    mpu.model_parallel_cuda_manual_seed(12345)
    assert torch.cuda.initial_seed() == 12345
    with mpu.get_cuda_rng_tracker().fork():
        assert torch.cuda.initial_seed() == (12345 + 2718 +
                                             mpu.get_tensor_model_parallel_rank())

    # Reset the tracker
    mpu.get_cuda_rng_tracker().reset()

    # Reset groups
    mpu.destroy_model_parallel()

    torch.distributed.barrier()
    if torch.distributed.get_rank() == 0:
        print('>> passed the test :-)')
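The assertions encode the seed scheme: the default CUDA generator keeps the user-provided seed on every tensor-parallel rank (so replicated operations stay in sync), while the tracked model-parallel generator is offset by the constant 2718 plus the TP rank so that ranks draw different values. A one-line restatement of the expectation checked above:

# Seed pair the assertions above verify, for a given base seed and TP rank.
def expected_seeds(seed, tp_rank):
    return seed, seed + 2718 + tp_rank  # (default generator, tracked generator)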
if __name__ == '__main__':

    initialize_distributed()
    world_size = torch.distributed.get_world_size()

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test set rng state')
        test_set_cuda_rng_state(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test cuda rng tracker')
        test_cuda_rng_tracker(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2

    tensor_model_parallel_size = 1
    while tensor_model_parallel_size <= world_size:
        print_separator('test model parallel cuda manual seed')
        test_model_parallel_cuda_manual_seed(tensor_model_parallel_size)
        tensor_model_parallel_size *= 2
megatron/optimizer/__init__.py
0 → 100644
View file @
d3dd8642
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

from apex.optimizers import FusedAdam as Adam
from apex.optimizers import FusedSGD as SGD

from megatron import get_args

from .distrib_optimizer import DistributedOptimizer
from .grad_scaler import ConstantGradScaler, DynamicGradScaler
from .optimizer import Float16OptimizerWithFloat16Params, FP32Optimizer


def get_param_groups(modules,
                     no_weight_decay_cond,
                     scale_lr_cond,
                     lr_mult):
    """Create param groups based on the weight decay condition (regularized vs
    non-regularized) and the learning rate scale condition (args.lr vs
    lr_mult * args.lr).

    scale_lr_cond is used during finetuning, where the head of the network
    requires a scaled version of the base learning rate.
    """
    wd_no_scale_lr = []
    wd_scale_lr = []
    no_wd_no_scale_lr = []
    no_wd_scale_lr = []
    for module in modules:
        for name, param in module.named_parameters():
            if not param.requires_grad:
                continue

            if no_weight_decay_cond is not None:
                no_wd = no_weight_decay_cond(name, param)
            else:
                # Do not regularize biases or Norm parameters.
                no_wd = name.endswith(".bias") or len(param.shape) == 1

            if scale_lr_cond is not None:
                scale_lr = scale_lr_cond(name, param)
            else:
                scale_lr = False

            if not no_wd and not scale_lr:
                wd_no_scale_lr.append(param)
            elif not no_wd and scale_lr:
                wd_scale_lr.append(param)
            elif no_wd and not scale_lr:
                no_wd_no_scale_lr.append(param)
            else:
                no_wd_scale_lr.append(param)

    param_groups = []
    if len(wd_no_scale_lr):
        param_groups.append(
            {'params': wd_no_scale_lr, 'wd_mult': 1.0, 'lr_mult': 1.0})
    if len(wd_scale_lr):
        param_groups.append(
            {'params': wd_scale_lr, 'wd_mult': 1.0, 'lr_mult': lr_mult})
    if len(no_wd_no_scale_lr):
        param_groups.append(
            {'params': no_wd_no_scale_lr, 'wd_mult': 0.0, 'lr_mult': 1.0})
    if len(no_wd_scale_lr):
        param_groups.append(
            {'params': no_wd_scale_lr, 'wd_mult': 0.0, 'lr_mult': lr_mult})

    return param_groups
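Both condition arguments are callbacks that receive (name, param) and return a bool; when they are None, the defaults apply (biases and 1-D parameters such as Norm weights skip weight decay, and nothing gets a scaled learning rate). A hedged usage sketch, where the module-prefix string and the model_chunks list are hypothetical placeholders, not names from this repository:

# Hedged sketch: scale the learning rate of a hypothetical finetuning head
# while keeping the default no-weight-decay rule.
def scale_head_lr(name, param):
    return name.startswith('classification_head')  # hypothetical module prefix

param_groups = get_param_groups(model_chunks,       # list of model modules
                                no_weight_decay_cond=None,
                                scale_lr_cond=scale_head_lr,
                                lr_mult=10.0)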
def get_megatron_optimizer(model,
                           no_weight_decay_cond=None,
                           scale_lr_cond=None,
                           lr_mult=1.0):
    args = get_args()

    # Base optimizer.
    param_groups = get_param_groups(model,
                                    no_weight_decay_cond,
                                    scale_lr_cond,
                                    lr_mult)

    if args.optimizer == 'adam':
        optimizer = Adam(param_groups,
                         lr=args.lr,
                         weight_decay=args.weight_decay,
                         betas=(args.adam_beta1, args.adam_beta2),
                         eps=args.adam_eps)
    elif args.optimizer == 'sgd':
        optimizer = SGD(param_groups,
                        lr=args.lr,
                        weight_decay=args.weight_decay,
                        momentum=args.sgd_momentum)
    else:
        raise Exception('{} optimizer is not supported.'.format(
            args.optimizer))

    # Determine whether the params have main-grad field.
    params_have_main_grad = False
    if args.DDP_impl == 'local':
        params_have_main_grad = True

    # Mixed precision optimizer.
    # - Note: both the Float16Optimizer and the DistributedOptimizer inherit
    #   from the MixedPrecisionOptimizer, which manages any optimizer where
    #   the model params and main params are distinct.
    if args.fp16 or args.bf16 or args.use_distributed_optimizer:

        # Grad scaler:
        #    if loss-scale is provided, instantiate the constant scaler.
        #    if we are using fp16 and loss-scale is not present, use a
        #       dynamic scaler.
        #    otherwise we are running in bf16 with no loss-scale so
        #       leave it as None.
        grad_scaler = None

        # Constant loss scale.
        if args.loss_scale:
            grad_scaler = ConstantGradScaler(args.loss_scale)

        # Dynamic loss scale.
        else:
            if args.fp16:
                grad_scaler = DynamicGradScaler(
                    initial_scale=args.initial_loss_scale,
                    min_scale=args.min_loss_scale,
                    growth_factor=2.0,
                    backoff_factor=0.5,
                    growth_interval=args.loss_scale_window,
                    hysteresis=args.hysteresis)

        # Megatron optimizer.
        opt_ty = DistributedOptimizer \
            if args.use_distributed_optimizer else \
            Float16OptimizerWithFloat16Params
        return opt_ty(optimizer,
                      args.clip_grad,
                      args.log_num_zeros_in_grad,
                      params_have_main_grad,
                      args.use_contiguous_buffers_in_local_ddp,
                      args.fp16,
                      args.bf16,
                      args.params_dtype,
                      grad_scaler,
                      model)

    # FP32.
    return FP32Optimizer(optimizer,
                         args.clip_grad,
                         args.log_num_zeros_in_grad,
                         params_have_main_grad,
                         args.use_contiguous_buffers_in_local_ddp,
                         model)
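The grad-scaler selection in the mixed-precision branch reduces to three cases: an explicit --loss-scale gives a constant scaler, fp16 without an explicit scale gets a dynamic scaler, and bf16 (or fp32 main params) runs with no scaler at all. A stand-alone restatement, sketch only; the real code reads these fields from Megatron's parsed args rather than taking them as parameters:

# Sketch of the scaler decision above, using the classes imported in this file.
def pick_grad_scaler(fp16, bf16, loss_scale, dyn_kwargs):
    if loss_scale:        # explicit loss scale: constant scaling
        return ConstantGradScaler(loss_scale)
    if fp16:              # fp16 without a fixed scale: dynamic scaling
        return DynamicGradScaler(**dyn_kwargs)
    return None           # bf16 / fp32 main params: no scaler needed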
megatron/optimizer/clip_grads.py
0 → 100644
View file @
d3dd8642
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.

"""Gradient clipping."""

import torch
from torch import inf

from apex.multi_tensor_apply import multi_tensor_applier
import amp_C

from megatron.model.module import param_is_not_shared
from megatron.core.tensor_parallel import param_is_not_tensor_parallel_duplicate


def clip_grad_norm_fp32(parameters, grads_for_norm,
                        max_norm, norm_type=2,
                        model_parallel_group=None):
    """Clip the gradient norm of an iterable of parameters whose gradients
    are in fp32.

    This is adapted from torch.nn.utils.clip_grad.clip_grad_norm_ and adds
    functionality to handle model-parallel parameters. Note that the
    gradients are modified in place.

    Arguments:
        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
            single Tensor that will have gradients normalized.
        grads_for_norm (Iterable[Tensor]): an iterable of Tensors or a single
            Tensor that will be used for calculating the grad norm.
        max_norm (float or int): max norm of the gradients.
        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
            infinity norm.
        model_parallel_group (group): given the nature of the distributed
            optimizer, this is passed as an argument.

    Returns:
        Total norm of the parameters (viewed as a single vector).
    """

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    if isinstance(grads_for_norm, torch.Tensor):
        grads_for_norm = [grads_for_norm]

    # Grads.
    grads = []
    for param in parameters:
        if param.grad is not None:
            assert param.grad.type() == 'torch.cuda.FloatTensor'
            grads.append(param.grad.detach())

    # Norm parameters.
    max_norm = float(max_norm)
    norm_type = float(norm_type)
    total_norm = 0.0

    # Calculate norm.
    if norm_type == inf:
        total_norm = max(grad.abs().max() for grad in grads_for_norm)
        total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
        # Take max across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm_cuda,
                                     op=torch.distributed.ReduceOp.MAX,
                                     group=model_parallel_group)
        total_norm = total_norm_cuda[0].item()

    else:
        if norm_type == 2.0:
            dummy_overflow_buf = torch.cuda.IntTensor([0])
            # Use apex's multi-tensor applier for efficiency reasons.
            # Multi-tensor applier takes a function and a list of list
            # and performs the operation on that list all in one kernel.
            if grads_for_norm:
                grad_norm, _ = multi_tensor_applier(
                    amp_C.multi_tensor_l2norm,
                    dummy_overflow_buf,
                    [grads_for_norm],
                    False  # no per-parameter norm
                )
            else:
                grad_norm = torch.cuda.FloatTensor([0])
            # Since we will be summing across data parallel groups,
            # we need the pow(norm-type).
            total_norm = grad_norm ** norm_type

        else:
            for grad in grads_for_norm:
                grad_norm = torch.norm(grad, norm_type)
                total_norm += grad_norm ** norm_type

        # Sum across all model-parallel GPUs.
        torch.distributed.all_reduce(total_norm,
                                     op=torch.distributed.ReduceOp.SUM,
                                     group=model_parallel_group)
        total_norm = total_norm.item() ** (1.0 / norm_type)

    # Scale.
    clip_coeff = max_norm / (total_norm + 1.0e-6)
    if clip_coeff < 1.0:
        dummy_overflow_buf = torch.cuda.IntTensor([0])
        multi_tensor_applier(amp_C.multi_tensor_scale,
                             dummy_overflow_buf,
                             [grads, grads],
                             clip_coeff)

    return total_norm
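Callers are expected to pass all fp32 gradients to be scaled in `parameters`, but only gradients that should contribute to the norm exactly once across the model-parallel group in `grads_for_norm` (i.e. excluding shared parameters and tensor-parallel duplicates). A hedged sketch of how a caller might assemble that second list; the function name is illustrative and `group` stands for the caller's model-parallel process group:

# Hedged sketch: build the norm list by filtering shared / TP-duplicate params.
def collect_grads_for_norm(params):
    grads_for_norm = []
    for p in params:
        if p.grad is None:
            continue
        if param_is_not_shared(p) and param_is_not_tensor_parallel_duplicate(p):
            grads_for_norm.append(p.grad.detach())
    return grads_for_norm

# total_norm = clip_grad_norm_fp32(params, collect_grads_for_norm(params),
#                                  max_norm=1.0, model_parallel_group=group)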
def count_zeros_fp32(parameters, model_parallel_group):

    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]

    # Filter parameters based on:
    #   - grad should not be none
    #   - parameter should not be shared
    #   - should not be a replica due to tensor model parallelism
    total_num_zeros = torch.cuda.FloatTensor([0.0])
    for param in parameters:
        grad_not_none = param.grad is not None
        is_not_shared = param_is_not_shared(param)
        is_not_tp_duplicate = param_is_not_tensor_parallel_duplicate(param)
        if grad_not_none and is_not_shared and is_not_tp_duplicate:
            grad = param.grad.detach()
            num_zeros = grad.numel() - torch.count_nonzero(grad)
            total_num_zeros = num_zeros + total_num_zeros

    # Sum across all model-parallel GPUs.
    torch.distributed.all_reduce(total_num_zeros,
                                 op=torch.distributed.ReduceOp.SUM,
                                 group=model_parallel_group)

    total_num_zeros = total_num_zeros.item()

    return total_num_zeros