Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
ColossalAI
Commits
96134e7b
Unverified
Commit
96134e7b
authored
Nov 29, 2022
by
Jiarui Fang
Committed by
GitHub
Nov 29, 2022
Browse files
[hotfix] add bert test for gemini fwd bwd (#2035)
parent
0dbcd4a6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
11 additions
and
13 deletions
+11
-13
colossalai/nn/_ops/element_wise.py
colossalai/nn/_ops/element_wise.py
+8
-10
colossalai/nn/parallel/data_parallel.py
colossalai/nn/parallel/data_parallel.py
+1
-1
tests/test_gemini/update/test_fwd_bwd.py
tests/test_gemini/update/test_fwd_bwd.py
+2
-2
No files found.
colossalai/nn/_ops/element_wise.py
View file @
96134e7b
...
...
@@ -34,17 +34,15 @@ def register_elementwise_op(op):
dist_attr
=
input_tensor
.
dist_spec
))
@colo_op_impl(torch.relu_)
def elementwise_op(input_tensor):
    """Dispatch ``torch.relu_`` for a ColoTensor.

    Applies the in-place ReLU to the wrapped ``.data`` payload and hands
    back the same wrapper object, mirroring torch's in-place convention
    of returning the mutated tensor.
    """
    torch.relu_(input_tensor.data)
    return input_tensor
@colo_op_impl(Tensor.add_)
def elementwise_op(input_tensor: ColoTensor, *args, **kwargs):
    """Dispatch ``Tensor.add_`` for a ColoTensor.

    Forwards the in-place add to the wrapped ``.data`` payload.

    NOTE(review): this returns the result of ``.data.add_(...)`` — i.e.
    the raw underlying tensor, not the ColoTensor wrapper, unlike the
    ``relu_`` overload above. Confirm callers expect the unwrapped value.
    """
    result = input_tensor.data.add_(*args, **kwargs)
    return result
# @colo_op_impl(torch.relu_)
# def elementwise_op(input_tensor):
# torch.relu_(input_tensor.data)
# return input_tensor
# @colo_op_impl(Tensor.add_)
# def elementwise_op(input_tensor: ColoTensor, *args, **kwargs):
# input_tensor = input_tensor.data.add_(*args, **kwargs)
# return input_tensor
# Tensor op
# Register the out-of-place elementwise dispatchers on ColoTensor.
# (Further register_elementwise_op calls follow in the elided hunk.)
register_elementwise_op(Tensor.abs)
...
...
colossalai/nn/parallel/data_parallel.py
View file @
96134e7b
...
...
@@ -272,7 +272,7 @@ class ZeroDDP(ColoDDP):
p
.
grad
=
None
def
_post_backward
(
self
):
#
assert self.chunk_manager.accessed_mem == 0
assert
self
.
chunk_manager
.
accessed_mem
==
0
self
.
_setup_grads_ptr
()
self
.
_logger
.
debug
(
f
'comp cuda demand time:
{
self
.
gemini_manager
.
_comp_cuda_demand_time
}
, layout time:
{
self
.
gemini_manager
.
_layout_time
}
, evict time:
{
self
.
gemini_manager
.
_evict_time
}
, CPU->CUDA vol:
{
self
.
gemini_manager
.
_h2d_volume
}
B, CUDA->CPU vol:
{
self
.
gemini_manager
.
_d2h_volume
}
'
...
...
tests/test_gemini/update/test_fwd_bwd.py
View file @
96134e7b
...
...
@@ -33,7 +33,7 @@ def check_grad(model: ZeroDDP, torch_model: torch.nn.Module):
@
parameterize
(
'placement_policy'
,
[
'cuda'
,
'cpu'
,
'auto'
,
'const'
])
@
parameterize
(
'keep_gather'
,
[
False
,
True
])
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
,
'resnet18'
])
@
parameterize
(
'model_name'
,
[
'gpt2'
,
'bert'
])
@
parameterize
(
'use_grad_checkpoint'
,
[
False
,
True
])
def
exam_gpt_fwd_bwd
(
placement_policy
,
keep_gather
,
model_name
:
str
,
use_grad_checkpoint
:
bool
=
False
):
set_seed
(
42
)
...
...
@@ -78,7 +78,7 @@ def exam_gpt_fwd_bwd(placement_policy, keep_gather, model_name: str, use_grad_ch
torch
.
max
(
torch
.
abs
(
loss
-
torch_loss
)).
item
(),
loss
,
torch_loss
)
# FIXME(1SAA) bert and resnet18 can not pass the check_grad
#
check_grad(model, torch_model)
check_grad
(
model
,
torch_model
)
def
run_dist
(
rank
,
world_size
,
port
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment