OpenDAS / ColossalAI / Commits / 2e9cbfca

Unverified commit 2e9cbfca, authored Nov 24, 2022 by Jiarui Fang, committed by GitHub on Nov 24, 2022.

[Gemini] add unitests to check gemini correctness (#2015)

Parent: 0b0d8f9e
Showing 13 changed files with 133 additions and 52 deletions (+133, −52).
tests/components_to_test/__init__.py (+1, −0)
tests/components_to_test/gpt.py (+7, −5)
tests/components_to_test/inline_op_model.py (+1, −1)
tests/components_to_test/utils/__init__.py (+2, −1)
tests/components_to_test/utils/executor.py (+15, −0)
tests/test_gemini/test_gemini_train.py (+67, −0)
tests/test_gemini/test_mem_tracer.py (+3, −13)
tests/test_gemini/test_param_op.py (+1, −1)
tests/test_gemini/update/test_fwd_bwd.py (+5, −5)
tests/test_gemini/update/test_optim.py (+8, −8)
tests/test_gemini/update/test_zerooptim_state_dict.py (+2, −2)
tests/test_tensor/model/test_gpt2.py (+16, −11)
tests/test_tensor/test_tp_with_zero.py (+5, −5)
tests/components_to_test/__init__.py  View file @ 2e9cbfca

```diff
 from . import bert, gpt, inline_op_model, nested_model, no_leaf_module, repeated_computed_layer, resnet, simple_net
+from .utils import run_fwd_bwd
```
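This re-export is what lets the updated tests import the shared helper directly from the package, e.g. `from tests.components_to_test import run_fwd_bwd`, as `tests/test_gemini/test_gemini_train.py` and `tests/test_gemini/test_mem_tracer.py` do below.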
tests/components_to_test/gpt.py  View file @ 2e9cbfca

```diff
 import torch
 import torch.nn as nn
-from .registry import non_distributed_component_funcs
 from transformers import GPT2Config, GPT2LMHeadModel
-from .utils.dummy_data_generator import DummyDataGenerator
+
 from colossalai.utils.cuda import get_current_device
+
+from .registry import non_distributed_component_funcs
+from .utils.dummy_data_generator import DummyDataGenerator


 class DummyDataLoader(DummyDataGenerator):
     vocab_size = 128
@@ -15,8 +17,7 @@ class DummyDataLoader(DummyDataGenerator):
         input_ids = torch.randint(0,
                                   DummyDataLoader.vocab_size, (DummyDataLoader.batch_size, DummyDataLoader.seq_len),
                                   device=get_current_device())
-        attention_mask = torch.ones_like(input_ids)
-        return input_ids, attention_mask
+        return input_ids, input_ids


 class GPTLMModel(nn.Module):
@@ -43,8 +44,9 @@ class GPTLMModel(nn.Module):
         if checkpoint:
             self.model.gradient_checkpointing_enable()

-    def forward(self, input_ids, attention_mask):
+    def forward(self, input_ids):
         # Only return lm_logits
+        attention_mask = torch.ones_like(input_ids)
         return self.model(input_ids=input_ids, attention_mask=attention_mask, use_cache=not self.checkpoint)[0]
```
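With this change the GPT test component no longer takes an attention mask from the caller: `forward` accepts `input_ids` alone and builds an all-ones mask internally, and the dummy loader returns `(input_ids, input_ids)` so the same tensor serves as both data and label. A minimal sketch of the calling convention before and after (variable names assumed):

```python
# Before this commit the caller had to supply the mask explicitly:
#   logits = model(input_ids, attn_mask)
# After it, the mask is derived inside forward via torch.ones_like(input_ids):
logits = model(input_ids)
```

This is why every test touched below drops its `attn_mask` argument.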
tests/components_to_test/inline_op_model.py  View file @ 2e9cbfca

```diff
@@ -38,7 +38,7 @@ class DummyDataLoader(DummyDataGenerator):
         return data, label


-@non_distributed_component_funcs.register(name='inline_op_module')
+@non_distributed_component_funcs.register(name='inline_op_model')
 def get_training_components():

     def model_builder(checkpoint=True):
```
tests/components_to_test/utils/__init__.py  View file @ 2e9cbfca

```diff
 from .dummy_data_generator import DummyDataGenerator
+from .executor import run_fwd_bwd
```
tests/components_to_test/utils/executor.py  0 → 100644  View file @ 2e9cbfca

```python
import torch


def run_fwd_bwd(model, data, label, criterion, enable_autocast=False, use_init_ctx=False):
    with torch.cuda.amp.autocast(enabled=enable_autocast):
        if criterion:
            y = model(data)
            loss = criterion(y, label)
        else:
            loss = model(data, label)
        loss = loss.float()
    if use_init_ctx:
        model.backward(loss)
    else:
        loss.backward()
```
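As a usage sketch (not part of the commit), the `use_init_ctx` flag selects between plain autograd and the `model.backward(loss)` path that ColossalAI's wrapped models expose; `net`, `zero_ddp_net`, `data`, `label`, and `criterion` are assumed names:

```python
# Plain PyTorch model: gradients flow through loss.backward().
run_fwd_bwd(net, data.cuda(), label.cuda(), criterion,
            enable_autocast=False, use_init_ctx=False)

# Model built under ColoInitContext and wrapped in ZeroDDP: the wrapper
# owns the backward pass, so run_fwd_bwd calls model.backward(loss).
run_fwd_bwd(zero_ddp_net, data.cuda(), label.cuda(), criterion,
            enable_autocast=False, use_init_ctx=True)
```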
tests/test_gemini/test_gemini_train.py  0 → 100644  View file @ 2e9cbfca

```python
from functools import partial

import pytest
import torch
import torch.multiprocessing as mp

import colossalai
from colossalai.logging import disable_existing_loggers, get_dist_logger
from colossalai.nn.parallel import ZeroDDP
from colossalai.testing import rerun_if_address_is_in_use
from colossalai.utils import free_port, get_current_device
from colossalai.utils.model.colo_init_context import ColoInitContext
from tests.components_to_test import run_fwd_bwd
from tests.components_to_test.registry import non_distributed_component_funcs


def run_gemini_fwd_bwd(rank, world_size, port, model_name: str, iter_num=2):
    PLACEMENT_POLICY = 'cuda'
    disable_existing_loggers()
    colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')

    get_components_func = non_distributed_component_funcs.get_callable(model_name)
    model_builder, train_dataloader, _, _, criterion = get_components_func()

    # build torch model
    model_torch = model_builder(checkpoint=False).cuda()
    for i, (data, label) in enumerate(train_dataloader):
        if i >= iter_num:
            break
        run_fwd_bwd(model_torch, data.cuda(), label.cuda(), criterion, False, use_init_ctx=False)

    # build CAI model
    with ColoInitContext(device=get_current_device()):
        model = model_builder(checkpoint=False)

    from colossalai.gemini import ChunkManager, GeminiManager, search_chunk_configuration
    config_dict, _ = search_chunk_configuration(model, search_range_mb=1, search_interval_byte=100)
    chunk_manager = ChunkManager(config_dict, init_device=GeminiManager.get_default_device(PLACEMENT_POLICY))
    gemini_manager = GeminiManager(PLACEMENT_POLICY, chunk_manager)
    model = ZeroDDP(model, gemini_manager)

    model.train()
    for i, (data, label) in enumerate(train_dataloader):
        if i >= iter_num:
            break
        run_fwd_bwd(model, data.cuda(), label.cuda(), criterion, False, use_init_ctx=True)

    for p1, p2 in zip(model.parameters(), model_torch.parameters()):
        torch.allclose(p1.to(torch.float), p2.to(torch.float))

    print(f'pass test {model_name}')


@pytest.mark.parametrize("model_name", ['bert'])
@rerun_if_address_is_in_use()
def test_gemini_train(model_name, iter_num=2):
    run_func = partial(run_gemini_fwd_bwd, world_size=1, port=free_port(), model_name=model_name, iter_num=iter_num)
    mp.spawn(run_func, nprocs=1)


if __name__ == '__main__':
    # for model_name in ["bert", "resnet18", "inline_op_model"]:
    # bert, gpt, inline_op_model, nested_model, no_leaf_module,
    # repeated_computed_layer, resnet, simple_net
    for model_name in ["nested_model", "no_leaf_module"]:
        test_gemini_train(model_name=model_name, iter_num=4)
```
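One detail worth noting: the final comparison loop calls `torch.allclose` but discards the returned bool, so a parameter mismatch would not actually fail the test as committed. A stricter variant, shown only as a sketch and not part of this commit, would assert on the result:

```python
# Sketch: fail loudly if the Gemini-wrapped model diverges from plain torch.
for p1, p2 in zip(model.parameters(), model_torch.parameters()):
    assert torch.allclose(p1.to(torch.float), p2.to(torch.float)), \
        f'parameter mismatch when training {model_name}'
```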
tests/test_gemini/test_mem_tracer.py  View file @ 2e9cbfca

```diff
@@ -8,20 +8,10 @@ import colossalai
 from colossalai.gemini.memory_tracer import MemtracerWrapper
 from colossalai.testing import rerun_if_address_is_in_use
 from colossalai.utils import free_port
+from tests.components_to_test import run_fwd_bwd
 from tests.components_to_test.registry import non_distributed_component_funcs


-def run_fwd_bwd(model, data, label, criterion, enable_autocast=False):
-    with torch.cuda.amp.autocast(enabled=enable_autocast):
-        if criterion:
-            y = model(data)
-            loss = criterion(y, label)
-        else:
-            loss = model(data, label)
-        loss = loss.float()
-    model.backward(loss)
-
-
 def run_tracer(rank, world_size, port, use_grad_check=True):
     colossalai.launch(config={}, rank=rank, world_size=world_size, host='localhost', port=port, backend='nccl')
     test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'bert']
@@ -43,7 +33,7 @@ def run_tracer(rank, world_size, port, use_grad_check=True):
         data = data.cuda()
         label = label.cuda()

-        run_fwd_bwd(model, data, label, criterion, False)
+        run_fwd_bwd(model, data, label, criterion, False, use_init_ctx=False)
         model._ophook_list[0].print_non_model_data()
@@ -58,4 +48,4 @@ def test_tracer(world_size, use_grad_check):
 if __name__ == '__main__':
-    test_tracer(1)
+    test_tracer(1, True)
```
tests/test_gemini/test_param_op.py  View file @ 2e9cbfca

```diff
@@ -50,7 +50,7 @@ def run_model(model, inputs, label, criterion, use_param_hook=False):
 def test_base_param_hook():
-    test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'inline_op_module']
+    test_models = ['repeated_computed_layers', 'resnet18', 'no_leaf_module', 'inline_op_model']
     # test_models = ['bert']
     for model_name in test_models:
```
tests/test_gemini/update/test_fwd_bwd.py  View file @ 2e9cbfca

```diff
@@ -30,9 +30,9 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
         assert torch.allclose(p0, p1.grad, atol=1e-3, rtol=1e-5), "{}".format(torch.max(torch.abs(p0 - p1.grad)).item())


-def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
+def run_fwd_bwd(model, criterion, optimizer, input_ids):
     optimizer.zero_grad()
-    logits = model(input_ids, attn_mask)
+    logits = model(input_ids)
     logits = logits.float()
     loss = criterion(logits, input_ids)
     optimizer.backward(loss)
@@ -71,16 +71,16 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather):
     torch_model.eval()

     set_seed(pg.dp_local_rank())
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 0:
             break

-        logits = model(input_ids, attn_mask)
+        logits = model(input_ids)
         logits = logits.float()
         loss = criterion(logits, input_ids)
         model.backward(loss)
-        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
+        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
         assert torch.allclose(logits, torch_logits, rtol=0), "{} {} {}".format(
             torch.max(torch.abs(logits - torch_logits)).item(), logits, torch_logits)
```
tests/test_gemini/update/test_optim.py  View file @ 2e9cbfca

```diff
@@ -37,9 +37,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module):
         assert torch.allclose(value, temp_zero_value, rtol=1e-3, atol=1e-2), "parameter '{}' has problem.".format(key)


-def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
+def run_fwd_bwd(model, criterion, optimizer, input_ids):
     optimizer.zero_grad()
-    logits = model(input_ids, attn_mask)
+    logits = model(input_ids)
     logits = logits.float()
     loss = criterion(logits, input_ids)
     optimizer.backward(loss)
@@ -83,12 +83,12 @@ def exam_gpt_fwd_bwd(placement_policy):
     torch_model.eval()

     set_seed(dist.get_rank() * 3 + 128)
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 2:
             break

-        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
-        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
+        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids)
+        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
         assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
         # debug_print([0], zero_logits, torch_logits)
@@ -127,12 +127,12 @@ def exam_tiny_example(placement_policy):
     torch_model.eval()

     set_seed(dist.get_rank() * 3 + 128)
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 2:
             break

-        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids, attn_mask)
-        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
+        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids)
+        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
         assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)
         # debug_print([0], zero_logits, torch_logits)
```
tests/test_gemini/update/test_zerooptim_state_dict.py  View file @ 2e9cbfca

```diff
@@ -50,11 +50,11 @@ def exam_zero_optim_state_dict(placement_policy, keep_gathered):
     set_seed(dist.get_rank() * 3 + 128)
     model.train()
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 0:
             break
         optim.zero_grad()
-        logits = model(input_ids, attn_mask)
+        logits = model(input_ids)
         logits = logits.float()
         loss = criterion(logits, input_ids)
         optim.backward(loss)
```
tests/test_tensor/model/test_gpt2.py
View file @
2e9cbfca
import
pytest
from
functools
import
partial
from
functools
import
partial
from
tests.test_tensor.common_utils
import
tensor_equal
,
tensor_shard_equal
,
set_seed
import
pytest
import
torch
import
torch
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
import
torch.multiprocessing
as
mp
import
torch.multiprocessing
as
mp
from
torch.nn.parallel
import
DistributedDataParallel
as
DDP
import
colossalai
import
colossalai
from
colossalai.nn.parallel.data_parallel
import
ColoDDP
from
colossalai.tensor
import
ColoTensor
,
ColoTensorSpec
,
ComputePattern
,
ComputeSpec
,
ProcessGroup
,
ShardSpec
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.testing
import
rerun_if_address_is_in_use
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils
import
free_port
from
colossalai.utils
import
free_port
from
colossalai.utils.cuda
import
get_current_device
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.utils.model.colo_init_context
import
ColoInitContext
from
colossalai.tensor
import
ShardSpec
,
ComputePattern
,
ComputeSpec
,
ProcessGroup
,
ColoTensor
,
ColoTensorSpec
from
colossalai.nn.parallel.data_parallel
import
ColoDDP
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.components_to_test.registry
import
non_distributed_component_funcs
from
tests.test_tensor.common_utils
import
split_param_col_tp1d
,
split_param_row_tp1d
,
debug_print
from
tests.test_tensor.common_utils
import
(
debug_print
,
set_seed
,
split_param_col_tp1d
,
split_param_row_tp1d
,
tensor_equal
,
tensor_shard_equal
,
)
def
init_1d_row_spec
(
model
,
pg
:
ProcessGroup
):
def
init_1d_row_spec
(
model
,
pg
:
ProcessGroup
):
...
@@ -107,10 +112,10 @@ def run_gpt(init_spec_func, use_ddp):
...
@@ -107,10 +112,10 @@ def run_gpt(init_spec_func, use_ddp):
torch_model
.
eval
()
torch_model
.
eval
()
set_seed
(
pg
.
dp_local_rank
())
set_seed
(
pg
.
dp_local_rank
())
torch
.
distributed
.
barrier
()
torch
.
distributed
.
barrier
()
for
i
,
(
input_ids
,
attn_mask
)
in
enumerate
(
train_dataloader
):
for
i
,
(
input_ids
,
label
)
in
enumerate
(
train_dataloader
):
colo_input
=
ColoTensor
.
from_torch_tensor
(
input_ids
,
ColoTensorSpec
(
pg
))
colo_input
=
ColoTensor
.
from_torch_tensor
(
input_ids
,
ColoTensorSpec
(
pg
))
logits
=
model
(
colo_input
,
attn_mask
)
logits
=
model
(
colo_input
)
torch_logits
=
torch_model
(
input_ids
,
attn_mask
)
torch_logits
=
torch_model
(
input_ids
)
assert
tensor_equal
(
torch_logits
,
logits
),
f
"
{
torch_logits
-
logits
}
"
assert
tensor_equal
(
torch_logits
,
logits
),
f
"
{
torch_logits
-
logits
}
"
loss
=
criterion
(
logits
,
input_ids
)
loss
=
criterion
(
logits
,
input_ids
)
torch_loss
=
criterion
(
torch_logits
,
input_ids
)
torch_loss
=
criterion
(
torch_logits
,
input_ids
)
...
...
tests/test_tensor/test_tp_with_zero.py  View file @ 2e9cbfca

```diff
@@ -36,9 +36,9 @@ def check_param(model: ZeroDDP, torch_model: torch.nn.Module, pg: ProcessGroup):
         "parameter '{}' has problem.".format(key)


-def run_fwd_bwd(model, criterion, optimizer, input_ids, attn_mask):
+def run_fwd_bwd(model, criterion, optimizer, input_ids):
     optimizer.zero_grad()
-    logits = model(input_ids, attn_mask)
+    logits = model(input_ids)
     logits = logits.float()
     loss = criterion(logits, input_ids)
     optimizer.backward(loss)
@@ -117,12 +117,12 @@ def run_gpt(placement_policy, tp_init_spec_func=None):
     torch_model.eval()

     set_seed(pg.dp_local_rank())
-    for i, (input_ids, attn_mask) in enumerate(train_dataloader):
+    for i, (input_ids, label) in enumerate(train_dataloader):
         if i > 2:
             break

         input_ids_colo = ColoTensor.from_torch_tensor(input_ids, ColoTensorSpec(pg))
-        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo, attn_mask)
-        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids, attn_mask)
+        zero_logits = run_fwd_bwd(model, criterion, zero_optim, input_ids_colo)
+        torch_logits = run_fwd_bwd(torch_model, criterion, torch_optim, input_ids)
         assert torch.allclose(zero_logits, torch_logits, rtol=1e-3, atol=1e-2)

         zero_optim.step()
```