OpenDAS / ColossalAI / Commits / 6b30dfb7

Commit 6b30dfb7, authored Jun 13, 2023 by wukong1992, committed by Frank Lee on Jul 04, 2023

[shardformer] support llama model using shardformer (#3969)

adjust layer attr

parent 45927d55
Changes: 4 changed files with 243 additions and 1 deletion (+243 / -1)

    colossalai/shardformer/layer/dist_crossentropy.py          +1    -1
    colossalai/shardformer/policies/autopolicy.py               +14   -0
    colossalai/shardformer/policies/llama.py                    +122  -0
    tests/test_shardformer/test_model/test_shard_llama.py       +106  -0
colossalai/shardformer/layer/dist_crossentropy.py (view file @ 6b30dfb7)

...
@@ -21,7 +21,7 @@ class DistCrossEntropy(Function):
     and can be rewrite as:
         loss = log(sum(exp(x[i])) - x[class]
-    To avoid the `nan` of log(sim(exp(x[i]))), we minus the max of x[i]
+    To avoid the `nan` of log(sum(exp(x[i]))), we minus the max of x[i]

     Args:
         vocab_logits (:class:`torch.Tensor`): The logits of the vocabulary, shape is
...
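The corrected docstring line describes the standard log-sum-exp stabilization: subtracting max(x) before exponentiating prevents overflow without changing the loss value. A minimal sketch of that computation, for illustration only (this is not the DistCrossEntropy implementation itself, which additionally splits the vocabulary dimension across ranks):

import torch

def stable_cross_entropy(logits: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    # logits: (batch, vocab_size); target: (batch,) of class indices.
    # loss = log(sum(exp(x[i]))) - x[class], computed with the max subtracted
    # so exp() never overflows; the shift cancels out of the final value.
    max_logits = logits.max(dim=-1, keepdim=True).values
    shifted = logits - max_logits
    log_sum_exp = shifted.exp().sum(dim=-1).log() + max_logits.squeeze(-1)
    picked = logits.gather(-1, target.unsqueeze(-1)).squeeze(-1)
    return (log_sum_exp - picked).mean()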
colossalai/shardformer/policies/autopolicy.py (view file @ 6b30dfb7)

...
@@ -19,6 +19,20 @@ def build_policies():
     from .bert import BertForSequenceClassificationPolicy
     auto_policy_dict[BertForSequenceClassification] = BertForSequenceClassificationPolicy

+    from transformers.models.llama.modeling_llama import LlamaModel
+    from .llama import LlamaPolicy
+    auto_policy_dict[LlamaModel] = LlamaPolicy
+    from transformers import LlamaForSequenceClassification
+    from .llama import LlamaForSequenceClassificationPolicy
+    auto_policy_dict[LlamaForSequenceClassification] = LlamaForSequenceClassificationPolicy
+    from transformers import LlamaForCausalLM
+    from .llama import LlamaForCausalLMPolicy
+    auto_policy_dict[LlamaForCausalLM] = LlamaForCausalLMPolicy

     from transformers import GPT2Model
...
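build_policies() returns this registry, which maps concrete model classes to their shard policies. A hedged sketch of how a caller might resolve the policy for a model instance; the helper below is illustrative and is not part of this commit:

from transformers import LlamaForCausalLM

def resolve_policy(model, auto_policy_dict):
    # Illustrative lookup: the registry keys are model classes, the values policy classes.
    policy_cls = auto_policy_dict.get(type(model))
    if policy_cls is None:
        raise ValueError(f"no shardformer policy registered for {type(model).__name__}")
    return policy_cls

# e.g. resolve_policy(LlamaForCausalLM(config), build_policies()) -> LlamaForCausalLMPolicy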
colossalai/shardformer/policies/llama.py (new file, 0 → 100644; view file @ 6b30dfb7)

from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Tuple, Type

import torch.nn as nn
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaModel

import colossalai.shardformer.layer.layers as col_nn

from .basepolicy import Argument, Col_Layer, Policy, Row_Layer


class LlamaPolicy(Policy):

    @staticmethod
    def argument_policy(config, world_size: int) -> Dict[nn.Module, Argument]:
        return {
            LlamaDecoderLayer:
                Argument(attr_dict={
                    "self_attn.hidden_size": config.hidden_size // world_size,
                    "self_attn.num_heads": config.num_attention_heads // world_size,
                },
                         param_funcs=[LlamaPolicy.attn_layer, LlamaPolicy.mlp_layer]),
            LlamaModel:
                Argument(attr_dict={}, param_funcs=[LlamaPolicy.embeddings])
        }

    @staticmethod
    def attn_layer() -> List:
        return [
            Col_Layer(
                suffix="self_attn.q_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Col_Layer(
                suffix="self_attn.k_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Col_Layer(
                suffix="self_attn.v_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
            ),
            Row_Layer(
                suffix="self_attn.o_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Row,
            )
        ]

    @staticmethod
    def mlp_layer() -> List:
        return [
            Col_Layer(
                suffix="mlp.gate_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
                gather_output=True,
            ),
            Col_Layer(
                suffix="mlp.up_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Row,
                gather_output=True,
            ),
            Col_Layer(
                suffix="mlp.down_proj",
                weight="weight",
                bias="bias",
                replace_layer=col_nn.Linear1D_Col,
                gather_output=True,
            ),
        ]

    @staticmethod
    def embeddings() -> List:
        return [Col_Layer(
            suffix="embed_tokens",
            weight="weight",
            replace_layer=col_nn.VocabParallelEmbedding1D,
        )]


from transformers import LlamaForCausalLM


class LlamaForCausalLMPolicy(LlamaPolicy):

    @staticmethod
    def argument(config, world_size):
        llamapolicy = LlamaPolicy.argument_policy(config, world_size)
        argument = {LlamaForCausalLM: Argument(attr_dict={}, param_funcs=[LlamaForCausalLMPolicy.lm_head])}
        argument.update(llamapolicy)

    @staticmethod
    def lm_head() -> List:
        return [Col_Layer(suffix="lm_head", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]


from transformers import LlamaForSequenceClassification


class LlamaForSequenceClassificationPolicy(LlamaPolicy):

    @staticmethod
    def argument(config, world_size):
        llamapolicy = LlamaPolicy.argument_policy(config, world_size)
        argument = {
            LlamaForSequenceClassification:
                Argument(attr_dict={}, param_funcs=[LlamaForSequenceClassificationPolicy.score])
        }
        argument.update(llamapolicy)

    @staticmethod
    def score() -> List:
        return [Col_Layer(suffix="score", weight="weight", replace_layer=col_nn.Linear1D_Col, gather_output=True)]
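The new policy is picked up through the registry added in autopolicy.py; the test introduced in this commit exercises it indirectly via shard_model. A condensed sketch of that usage, assuming rank and world_size come from an already-launched colossalai 1D tensor-parallel context:

import copy
from transformers import LlamaConfig, LlamaForCausalLM
from colossalai.shardformer.shard import ShardConfig, shard_model

def build_sharded_llama(rank: int, world_size: int):
    # Mirrors build_model() in the test below: create the reference model,
    # then shard a deep copy of it under the LLaMA policy.
    config = LlamaConfig(num_hidden_layers=16)
    org_model = LlamaForCausalLM(config).to('cuda')
    shard_config = ShardConfig(rank=rank, world_size=world_size, gather_output=True)
    sharded_model = shard_model(copy.deepcopy(org_model), shard_config).to('cuda')
    return org_model, sharded_model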
tests/test_shardformer/test_model/test_shard_llama.py (new file, 0 → 100644; view file @ 6b30dfb7)

import copy
import os
import random

import pytest
import torch
from transformers import AutoTokenizer, LlamaConfig, LlamaForCausalLM, LlamaModel, LlamaTokenizerFast

import colossalai
from colossalai.logging import disable_existing_loggers
from colossalai.shardformer.shard import ShardConfig, shard_model
from colossalai.testing import rerun_if_address_is_in_use, spawn

os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = 'true'

CONFIG = dict(parallel=dict(data=1, pipeline=1, tensor=dict(size=4, mode='1d')),)
tokenizer = LlamaTokenizerFast.from_pretrained("hf-internal-testing/llama-tokenizer")


def build_model(rank, world_size):
    cfg = LlamaConfig(num_hidden_layers=16)
    org_model = LlamaForCausalLM(cfg)

    shardconfig = ShardConfig(
        rank=rank,
        world_size=world_size,
        gather_output=True,
    )
    org_model = org_model.to('cuda')

    org_model_forshard = copy.deepcopy(org_model)
    sharded_model = shard_model(org_model_forshard, shardconfig).to('cuda')

    return org_model, sharded_model


def check_forward(org_model, sharded_model):
    input = 'Hello, my dog is cute'
    inputs = tokenizer(input, return_tensors='pt').to('cuda')
    del inputs["token_type_ids"]
    del inputs["attention_mask"]

    # orgin model
    org_model.eval()
    org_out = org_model(**inputs)

    # shard model
    sharded_model.eval()
    shard_out = sharded_model(**inputs)

    assert torch.allclose(
        org_out[0], shard_out[0],
        atol=1e-4), f"shard model output is not equal to orgin model output\n{org_out[0]}\n{shard_out[0]}"


def check_backward(org_model, sharded_model):
    # prepare input
    input = 'Hello, my dog is cute'
    tokenized_input = tokenizer(input, return_tensors='pt').to('cuda')
    del tokenized_input["token_type_ids"]
    del tokenized_input["attention_mask"]
    labels = tokenized_input['input_ids'].clone()
    labels[labels == tokenizer.pad_token_id] = -100
    tokenized_input['labels'] = labels

    # orgin model
    org_model.train()
    org_out = org_model(**tokenized_input)
    org_loss = org_out.loss
    org_loss.backward()
    org_grad = org_model.model.layers[0].self_attn.q_proj.weight.grad

    torch.cuda.empty_cache()

    # shard model
    sharded_model.train()
    shard_out = sharded_model(**tokenized_input)
    shard_loss = shard_out.loss
    shard_loss.backward()
    shard_grad = sharded_model.model.layers[0].self_attn.q_proj.weight.grad

    shard_grad_list = [torch.zeros([*shard_grad.shape]).to('cuda') for _ in range(4)]
    shard_grad = torch.distributed.all_gather(shard_grad_list, shard_grad)
    all_shard_grad = torch.cat(shard_grad_list, dim=0)

    assert torch.allclose(org_loss, shard_loss,
                          atol=1e-5), f"shard model loss is not equal to orgin model loss\n{org_loss}\n{shard_loss}"
    assert torch.allclose(org_grad, all_shard_grad,
                          atol=1e-5), f"shard model grad is not equal to orgin model grad\n{org_grad}\n{shard_grad}"


def check_llama(rank, world_size, port):
    disable_existing_loggers()
    colossalai.launch(config=CONFIG, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    org_model, sharded_model = build_model(rank, world_size)
    check_forward(org_model, sharded_model)
    check_backward(org_model, sharded_model)

    torch.cuda.empty_cache()


@pytest.mark.dist
@rerun_if_address_is_in_use()
def test_llama():
    spawn(check_llama, 4)


if __name__ == "__main__":
    test_llama()
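In check_backward, each of the 4 ranks holds only its shard of q_proj.weight.grad, because Linear1D_Col splits that weight along its output dimension (dim 0); the test gathers the shards and concatenates them on dim 0 to compare against the unsharded gradient. A small self-contained sketch of that shape bookkeeping (sizes are illustrative, not taken from the commit):

import torch

world_size = 4
hidden = 4096                                   # illustrative hidden size
full_grad = torch.randn(hidden, hidden)         # unsharded q_proj.weight.grad
shards = full_grad.chunk(world_size, dim=0)     # each rank's shard under column parallelism
reassembled = torch.cat(shards, dim=0)          # what all_gather + torch.cat recovers in the test
assert torch.equal(full_grad, reassembled)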