OpenDAS / apex

Commit 86dfa18d, authored Aug 04, 2022 by flyingdown
Parent: 719215bd

    replace distributed_fused_lamb.py

Showing 3 changed files with 49 additions and 238 deletions.
  apex/contrib/csrc/multihead_attn/cutlass            +0   -1
  apex/contrib/optimizers/distributed_fused_lamb.py   +48  -236
  setup.py                                            +1   -1
apex/contrib/csrc/multihead_attn/cutlass @ ed2ed4d6 (+0 -1)

-Subproject commit ed2ed4d667ce95e1371bd62db32b6a114e774336
apex/contrib/optimizers/distributed_fused_lamb.py @ 86dfa18d (+48 -236)

import os
import math
import torch
import importlib
@@ -88,7 +87,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 step_supports_amp_scaling=True, overlap_reductions=True,
                 dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                 dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
-                e5m2_allgather=False, verbose=False, clip_after_ar=True):
+                e5m2_allgather=False, verbose=False):
        defaults = dict(lr=lr, bias_correction=bias_correction,
                        betas=betas, eps=eps, weight_decay=weight_decay,
                        grad_averaging=grad_averaging,
@@ -119,7 +118,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
        self._num_chunks = dwu_num_chunks
        self._e5m2_allgather = e5m2_allgather
        self._verbose = verbose
-       self._clip_after_ar = clip_after_ar
        self._L2_grad_norm = None
        self._current_process_group = c10d._get_default_group()
@@ -138,7 +136,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
        self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None
        import inspect
-       # assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
+       assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
        self._num_rs_pg = dwu_num_rs_pg
        self._num_ar_pg = dwu_num_ar_pg
@@ -267,48 +265,13 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
        self._total_param_size = p_offset
        dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
        self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')

    def _lazy_init_stage1(self):
        if self._lazy_init_stage1_done: return

        p_i = 0
        #self._model_params = []
        #self._grad_accs = []
        #self._group_properties = []
        for group in self.param_groups:
            for p in group['params']:
                torch.distributed.broadcast(p, 0)
                if not p.requires_grad:
                    continue
                def wrapper(param, param_i):
                    param_tmp = param.expand_as(param)
                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
                    def allreduce_hook(*unused):
                        if not self._set_flat_param_view:
                            if self._first_step:
                                # first time
                                self._param_order.add(param_i)
                            else:
                                idx = self._param_order.order.index(param_i)
                                self._do_overlapped_reduction(idx, param)
                        else:
                            if not self._first_step:
                                idx = self._param_order.order.index(param_i)
                                self._do_overlapped_reduction(idx, param)
                    grad_acc.register_hook(allreduce_hook)
                    self._grad_accs.append(grad_acc)
                wrapper(p, p_i)
                p_i += 1

        self._block_size = self._total_param_size // self._num_blocks
        self._chunk_size = self._block_size // self._num_chunks
        self._shard_size = self._chunk_size // self._group_size
        #print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
        self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
        self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
        # initialize master weights, moments buffers if not loaded from checkpoint
        if self._fp32_p is None:
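For context, the wrapper/allreduce_hook code above relies on PyTorch's gradient-accumulator nodes to learn when a parameter's gradient is ready. A minimal standalone sketch of that pattern (not part of this commit, written against the public autograd API) looks like this:

    import torch

    # A parameter and a throwaway expand to reach its AccumulateGrad node.
    p = torch.nn.Parameter(torch.randn(4, 4))
    param_tmp = p.expand_as(p)
    grad_acc = param_tmp.grad_fn.next_functions[0][0]   # the AccumulateGrad node for p

    def hook(*unused):
        # In DistributedFusedLAMB this is where the overlapped reduction is queued.
        print("gradient for p is ready")

    grad_acc.register_hook(hook)

    loss = (p * 2).sum()
    loss.backward()   # the hook fires during backward once p's gradient is accumulated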
@@ -327,18 +290,11 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
            def __shardify(p):
                return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
-           list_of_blocks = __blockify(p)
+           list_of_blocks = __blockify(self._flat_grads)
            list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
            list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
            return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
-       def _flat_split_no_shards(p):
-           def __blockify(p):
-               return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
-           def __chunkify(p):
-               return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
-           list_of_blocks = __blockify(self._flat_grads)
-           list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
-           return list_of_blocks, list_of_list_of_chunks
+       self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
        def _full_packed_split(p):
            def __shardify(p):
                return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
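As a side note, the __blockify/__chunkify/__shardify helpers above implement a purely index-based splitting of one flat buffer: blocks, then chunks, then one shard per rank. A small self-contained sketch of the same arithmetic (not from the diff, with made-up sizes):

    import torch

    num_blocks, num_chunks, group_size = 4, 4, 2
    total = 256 * num_blocks * num_chunks * group_size   # mirrors the dwu_min_page_size rounding

    block_size = total // num_blocks
    chunk_size = block_size // num_chunks
    shard_size = chunk_size // group_size

    flat = torch.arange(total, dtype=torch.float32)

    # blocks -> chunks -> per-rank shards, all of them views into the same flat buffer
    blocks = [flat[b * block_size:(b + 1) * block_size] for b in range(num_blocks)]
    chunks = [[blk[c * chunk_size:(c + 1) * chunk_size] for c in range(num_chunks)] for blk in blocks]
    shards = [[[ch[s * shard_size:(s + 1) * shard_size] for s in range(group_size)] for ch in chs] for chs in chunks]

    assert shards[0][0][0].numel() == shard_size
    assert shards[0][0][1][0].item() == shard_size   # rank 1's shard starts right after rank 0's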
@@ -350,6 +306,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
            list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
            list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
            return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
+       self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
        def _packed_split(p):
            def __packed_blockify(p):
                packed_block_size = self._num_chunks * self._shard_size
@@ -360,89 +317,15 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
            list_of_blocks = __packed_blockify(p)
            list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
            return list_of_blocks, list_of_list_of_chunks
        def _split_assign(shards):
            packed_block_size = self._num_chunks * self._shard_size
            list_of_list_of_chunks = []
            for block_id in range(self._num_blocks):
                list_of_chunks = []
                for chunk_id in range(self._num_chunks):
                    #self._fp16_g[block_id*packed_block_size+chunk_id*self._shard_size:block_id*packed_block_size+(chunk_id+1)*self._shard_size] = shards[block_id][chunk_id][self._rank_in_group]
                    list_of_chunks.append(shards[block_id][chunk_id][self._rank_in_group])
                list_of_list_of_chunks.append(list_of_chunks)
            return list_of_list_of_chunks
        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
        # this splitting scheme is needed when allgather needs to be split into multiple chunks in a contiguous way
        self._new_params2_blocks, self._new_params2_chunks, self._new_params2_shards = _flat_split(self._new_params)
        self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
        self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
        self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
        self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u)
        self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
        if self._full_ar:
            # for gradient all-reduce
            self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
            # for weight update
            self._fp16_g_chunks = _split_assign(self._flat_grads_shards)
        else:
            self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
        self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
        self._lazy_init_stage1_done = True

    def _lazy_init_stage2(self):
        if self._lazy_init_stage2_done: return
        if not self._set_flat_param_view:
            # reversing is needed for overlapping allreduce and backprop, but currently not supported for flat param view
            self._param_order.order.reverse()

        # re-order model_params, grad_accs, group_properties lists
        self._model_params = [self._model_params[i] for i in self._param_order.order]
        self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
        self._group_properties = [self._group_properties[i] for i in self._param_order.order]

        def _get_flat_view(param):
            if param.is_contiguous(memory_format=torch.channels_last):
                K, C, H, W = param.shape
                pv = param.as_strided(size=(K,H,W,C), stride=(H*W*C, W*C, C, 1))
            elif param.is_contiguous(memory_format=torch.channels_last_3d):
                K, C, D, H, W = param.shape
                pv = param.as_strided(size=(K,D,H,W,C), stride=(D*H*W*C, H*W*C, W*C, C, 1))
            else:
                pv = param
            return pv.view(-1)

        # re-collect grads info (size, offset) after ordering
        prev = None
        p_offset = 0
        self._grads_info = []
        self._individual_flat_grads = []
        for i, p in enumerate(self._model_params):
            p_grads_size = p.numel()
            self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
            self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))
            # for the first iteration
            self._do_overlapped_reduction(i, p)
            p_offset += p_grads_size
            # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
            # RNN is one example of consecutive parameters:
            # (weight_ih, weight_hh, bias_ih, bias_hh)
            if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
                p_offset = ((p_offset + 63) // 64) * 64
            prev = p
        self._low_param_i = [0] * self._num_blocks
        for block_id in range(self._num_blocks-1, -1, -1):
            p_i = len(self._grads_info)-1
            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
                p_i -= 1
            self._low_param_i[block_id] = p_i
        #print("self._low_param_i", self._low_param_i)
        self._lazy_init_stage1_done = True

    def _lazy_init_stage2(self):
        if self._lazy_init_stage2_done: return
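The _get_flat_view helper shown above re-views channels_last parameters in their physical memory order before flattening, so the flat copy matches the storage layout. A standalone sketch of that as_strided trick (not from the diff):

    import torch

    p = torch.randn(8, 3, 4, 4).to(memory_format=torch.channels_last)

    K, C, H, W = p.shape
    pv = p.as_strided(size=(K, H, W, C), stride=(H * W * C, W * C, C, 1))

    flat = pv.view(-1)                        # walks the buffer in storage order
    assert flat.data_ptr() == p.data_ptr()    # still the same memory, no copy
    assert flat[1].item() == pv[0, 0, 0, 1].item()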
@@ -508,8 +391,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                        grad_offset = clipped_start - flat_grad_start
                        grad_length = clipped_end - clipped_start
                        shard_offset = clipped_start - flat_shard_start
-                       pf = _get_flat_view(p)
-                       model_param_fragment = pf[grad_offset:grad_offset+grad_length]
+                       model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
                        new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
                        if model_param_fragment.dtype == torch.float16:
                            self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) )
@@ -588,7 +470,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
            return flush_block

    def _pipeline_block_reductions(self, block_id):
-       if self._clip_after_ar:
        self._flatten_grad_mt(1.0/self._world_size)
        # Reduction within each node
@@ -600,7 +481,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
            rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
            rs_stream.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(rs_stream):
-               works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id], self._flat_grads_shards[block_id][chunk_id], group=self._rs_pg[glob_chunk_id%self._num_rs_pg], async_op=True, no_copy=False)
+               works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id], self._flat_grads_shards[block_id][chunk_id], group=self._rs_pg[glob_chunk_id%self._num_rs_pg], async_op=True, no_copy=True)
        # Reduction across nodes for each rank
        if self._num_groups > 1:
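The reduce_scatter above is the intra-node half of a hierarchical all-reduce: each rank first receives the fully reduced copy of its own shard, those shards are then all-reduced across nodes, and the updated shards are all-gathered back after the optimizer step. A single-process sketch of the arithmetic (not from the diff; plain tensors stand in for ranks):

    import torch

    group_size, shard = 4, 8
    per_rank_grads = [torch.randn(group_size * shard) for _ in range(group_size)]

    # reduce-scatter: rank r ends up owning the sum of everyone's r-th shard
    owned = [sum(g[r * shard:(r + 1) * shard] for g in per_rank_grads) for r in range(group_size)]

    # all-gather of the owned shards reassembles what a plain all-reduce would have produced
    gathered = torch.cat(owned)
    reference = sum(per_rank_grads)
    assert torch.allclose(gathered, reference, atol=1e-5)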
@@ -623,52 +504,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
                torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
                self._L2_grad_norm = l2_grad_norm_sq.sqrt()
        else:
            # Copy model grads to flat grads buffer
            self._flatten_grad_mt(1.0)
            # Compute L2 grad norm
            self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())
            with torch.cuda.stream(self._l2_grad_norm_st):
                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()
            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)

            # Apply clipping & pre-reduction scaling on grads
            loss_scale = self.global_scale
            max_grad_norm = loss_scale * self.defaults['max_grad_norm']
            coeff = max_grad_norm / (1e-6 + self.L2_grad_norm)
            coeff = (coeff>1) * self._one + (coeff<=1) * coeff
            tmp = torch.cat(((self._one), (coeff)))
            index = (coeff+1>coeff).int()
            scale = tmp.index_select(0, index).half()/self._world_size
            self._flat_grads.mul_(scale)

            # Reduction within each node
            # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
            # The output format is the same as the fp32 master parameters
            works = [None]*self._num_chunks
            for chunk_id in range(self._num_chunks):
                glob_chunk_id = block_id * self._num_chunks + chunk_id
                rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
                rs_stream.wait_stream(torch.cuda.current_stream())
                rs_stream.wait_stream(self._l2_grad_norm_st)
                with torch.cuda.stream(rs_stream):
                    works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id], self._flat_grads_shards[block_id][chunk_id], group=self._rs_pg[glob_chunk_id%self._num_rs_pg], async_op=True, no_copy=False)

            # Reduction across nodes for each rank
            if self._num_groups > 1:
                for chunk_id in range(self._num_chunks):
                    glob_chunk_id = block_id * self._num_chunks + chunk_id
                    ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
                    with torch.cuda.stream(ar_stream):
                        works[chunk_id].wait()
                        works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id], group=self._ar_pg[glob_chunk_id%self._num_ar_pg], async_op=True)
            self._reductions_works[block_id] = works

        if block_id == 0:
            for block_id in range(self._num_blocks):
                for chunk_id in range(self._num_chunks):
                    self._reductions_works[block_id][chunk_id].wait()

    def __compute_contrib_param_norm(self):
        if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:
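The else branch above builds its gradient scaling factor without any host-side branching: coeff is clamped to at most 1, and index_select falls back to 1.0 when the norm is NaN or inf. A small sketch of that idiom (not from the diff):

    import torch

    one = torch.ones(1)

    def clip_scale(max_grad_norm, grad_norm):
        # min(1, max_grad_norm / grad_norm), falling back to 1 when grad_norm is inf or NaN
        coeff = max_grad_norm / (1e-6 + grad_norm)
        coeff = (coeff > 1) * one + (coeff <= 1) * coeff   # branch-free min(coeff, 1)
        tmp = torch.cat((one, coeff))
        index = (coeff + 1 > coeff).int()                  # 1 when coeff is finite, else 0
        return tmp.index_select(0, index)

    print(clip_scale(1.0, torch.tensor([4.0])))            # ~0.25: large gradients are scaled down
    print(clip_scale(1.0, torch.tensor([0.5])))            # 1.0: small gradients pass through
    print(clip_scale(1.0, torch.tensor([float('nan')])))   # 1.0: overflow is handled separately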
@@ -693,20 +528,12 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
    def _pipeline_step(self):
        global_scale = self.global_scale
-       # if clip before ar, set max_grad_norm to 0
-       max_grad_norm = self.defaults['max_grad_norm'] * self._clip_after_ar
-       self._completion_st.wait_stream(self._l2_grad_norm_st)
+       max_grad_norm = self.defaults['max_grad_norm']
        global_grad_norm = self.L2_grad_norm
        # check global_grad_norm and fill overflow_buf
        is_finite = (global_grad_norm + 1 > global_grad_norm).int()
        self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1
        torch.distributed.all_reduce(is_finite, op=torch.distributed.ReduceOp.MIN, group=self._current_process_group)
        torch.distributed.all_reduce(self._overflow_buf, op=torch.distributed.ReduceOp.MAX, group=self._current_process_group)
        # increment step counter if no overflow
        self._step += is_finite
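The is_finite line above uses the fact that x + 1 > x holds for ordinary finite floats but is false for NaN and +inf, so a single comparison doubles as the overflow check that feeds _overflow_buf. A tiny sketch (not from the diff):

    import torch

    def is_finite_flag(grad_norm):
        # 1 for a normal finite value, 0 for inf or NaN
        return (grad_norm + 1 > grad_norm).int()

    print(is_finite_flag(torch.tensor(3.5)))             # tensor(1, ...)
    print(is_finite_flag(torch.tensor(float('inf'))))    # tensor(0, ...)
    print(is_finite_flag(torch.tensor(float('nan'))))    # tensor(0, ...)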
@@ -745,39 +572,24 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                self._contrib_weight_decay,
                global_grad_norm,
                self._use_nvlamb)
-       torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=False)
+       torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)

    def _flatten_grad_mt(self, scale):
        if len(self._grads_fp16) > 0:
            self._overflow_buf.zero_()
-           if not self._fused_norm:
            multi_tensor_applier(
                    amp_C.multi_tensor_scale,
                    self._overflow_buf,
                    list(zip(*self._grads_fp16)),
                    scale)
-           else:
-               self._L2_grad_norm = multi_tensor_applier(amp_C.multi_tensor_l2norm_scale, self._overflow_buf, list(zip(*self._grads_fp16)), scale, False)[0].float()
            self._grads_fp16 = []
        if len(self._grads_fp32) > 0:
            self._overflow_buf.zero_()
-           if not self._fused_norm:
            multi_tensor_applier(
                    amp_C.multi_tensor_scale,
                    self._overflow_buf,
                    list(zip(*self._grads_fp32)),
                    scale)
-           else:
-               self._L2_grad_norm = multi_tensor_applier(amp_C.multi_tensor_l2norm_scale, self._overflow_buf, list(zip(*self._grads_fp32)), scale, False)[0].float()
            self._grads_fp32 = []

    def _do_overlapped_reduction(self, param_i, param):
setup.py @ 86dfa18d (+1 -1)

@@ -510,7 +510,7 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
        cc_flag.append('-gencode')
        cc_flag.append('arch=compute_86,code=sm_86')

-   subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
+   # subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
    nvcc_args_mha = ['-O3',
                     '-gencode', 'arch=compute_70,code=sm_70',