Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
97824232
"examples/python_rs/vscode:/vscode.git/clone" did not exist on "1af7433bffac503dc3ecbb6834f4baf6e9358c33"
Unverified
Commit
97824232
authored
Dec 07, 2022
by
Jiarui Fang
Committed by
GitHub
Dec 07, 2022
Browse files
[Gemini] remove eval in gemini unittests! (#2092)
parent
7f72eb05
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
17 additions
and
6 deletions
+17
-6
tests/test_gemini/update/test_fwd_bwd.py
tests/test_gemini/update/test_fwd_bwd.py
+17
-6
No files found.
tests/test_gemini/update/test_fwd_bwd.py
View file @
97824232
...
@@ -34,18 +34,25 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
...
@@ -34,18 +34,25 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
assert_close
(
p0
,
p1
.
grad
,
rtol
=
1e-3
,
atol
=
5e-5
)
assert_close
(
p0
,
p1
.
grad
,
rtol
=
1e-3
,
atol
=
5e-5
)
@
parameterize
(
'init_device'
,
[
get_current_device
()])
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
,
'auto'
,
'const'
])
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
,
'auto'
,
'const'
])
@
parameterize
(
'keep_gather'
,
[
False
,
True
])
@
parameterize
(
'keep_gather'
,
[
False
,
True
])
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
,
'albert'
])
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
,
'albert'
])
@
parameterize
(
'use_grad_checkpoint'
,
[
False
,
True
])
@
parameterize
(
'use_grad_checkpoint'
,
[
False
,
True
])
def
exam_gpt_fwd_bwd
(
placement_policy
,
keep_gather
,
model_name
:
str
,
use_grad_checkpoint
:
bool
=
False
):
def
exam_gpt_fwd_bwd
(
placement_policy
,
set_seed
(
42
)
keep_gather
,
model_name
:
str
,
use_grad_checkpoint
:
bool
=
False
,
init_device
=
get_current_device
()):
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
get_components_func
=
non_distributed_component_funcs
.
get_callable
(
model_name
)
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
model_builder
,
train_dataloader
,
test_dataloader
,
optimizer_class
,
criterion
=
get_components_func
()
with
ColoInitContext
(
device
=
get_current_device
()):
set_seed
(
42
)
with
ColoInitContext
(
device
=
init_device
):
model
=
model_builder
(
use_grad_checkpoint
)
model
=
model_builder
(
use_grad_checkpoint
)
set_seed
(
42
)
torch_model
=
model_builder
(
use_grad_checkpoint
).
cuda
()
torch_model
=
model_builder
(
use_grad_checkpoint
).
cuda
()
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
for
torch_p
,
p
in
zip
(
torch_model
.
parameters
(),
model
.
parameters
()):
torch_p
.
data
.
copy_
(
p
.
data
)
torch_p
.
data
.
copy_
(
p
.
data
)
...
@@ -66,9 +73,6 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
...
@@ -66,9 +73,6 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
,
torch_optim
=
convert_to_apex_amp
(
torch_model
,
torch_optim
,
amp_config
)
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
pg
.
rank
()],
process_group
=
pg
.
dp_process_group
())
torch_model
=
DDP
(
torch_model
,
device_ids
=
[
pg
.
rank
()],
process_group
=
pg
.
dp_process_group
())
model
.
eval
()
torch_model
.
eval
()
set_seed
(
pg
.
dp_local_rank
())
set_seed
(
pg
.
dp_local_rank
())
for
i
,
(
input_ids
,
label
)
in
enumerate
(
train_dataloader
):
for
i
,
(
input_ids
,
label
)
in
enumerate
(
train_dataloader
):
# you can only test a single fwd + bwd.
# you can only test a single fwd + bwd.
...
@@ -76,7 +80,14 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
...
@@ -76,7 +80,14 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
if
i
>
0
:
if
i
>
0
:
break
break
input_ids
,
label
=
input_ids
.
cuda
(),
label
.
cuda
()
input_ids
,
label
=
input_ids
.
cuda
(),
label
.
cuda
()
torch_optim
.
zero_grad
()
zero_optim
.
zero_grad
()
# setting the same random seed for both models takes the place of calling torch_model.eval()
set_seed
(
42
)
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
,
label
,
criterion
,
torch_optim
)
torch_loss
=
run_fwd_bwd
(
torch_model
,
input_ids
,
label
,
criterion
,
torch_optim
)
set_seed
(
42
)
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
zero_optim
)
loss
=
run_fwd_bwd
(
model
,
input_ids
,
label
,
criterion
,
zero_optim
)
assert
torch
.
equal
(
torch_loss
,
loss
)
assert
torch
.
equal
(
torch_loss
,
loss
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment