Commit 86dfa18d authored by flyingdown

replace distributed_fused_lamb.py

parent 719215bd
Subproject commit ed2ed4d667ce95e1371bd62db32b6a114e774336
-import os
 import math
 import torch
 import importlib
@@ -88,7 +87,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                  step_supports_amp_scaling=True, overlap_reductions=True,
                  dwu_group_size=0, dwu_num_blocks=4, dwu_num_chunks=4,
                  dwu_num_rs_pg=1, dwu_num_ar_pg=4, dwu_num_ag_pg=0,
-                 e5m2_allgather=False, verbose=False, clip_after_ar=True):
+                 e5m2_allgather=False, verbose=False):
         defaults = dict(lr=lr, bias_correction=bias_correction,
                         betas=betas, eps=eps, weight_decay=weight_decay,
                         grad_averaging=grad_averaging,
@@ -119,7 +118,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._num_chunks = dwu_num_chunks
         self._e5m2_allgather = e5m2_allgather
         self._verbose = verbose
-        self._clip_after_ar = clip_after_ar
         self._L2_grad_norm = None
         self._current_process_group = c10d._get_default_group()
@@ -138,7 +136,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._fp32_p, self._fp32_m, self._fp32_v, self._fp16_p, self._fp16_g = None, None, None, None, None
         import inspect
-        #assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
+        assert ('no_copy' in inspect.getfullargspec(torch.distributed.reduce_scatter).args), "This version of c10d does not support no_copy option"
         self._num_rs_pg = dwu_num_rs_pg
         self._num_ar_pg = dwu_num_ar_pg
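For context: the assert re-enabled above only passes on a c10d build that accepts a `no_copy` keyword (NVIDIA's container builds of PyTorch typically carry this patch; upstream PyTorch generally does not expose it). A small standalone probe in the same spirit, useful before enabling the `no_copy=True` calls that appear later in this diff (a sketch, not part of the commit):

import inspect
import torch.distributed as dist

def reduce_scatter_supports_no_copy() -> bool:
    """Return True if this c10d build exposes the non-standard `no_copy` kwarg."""
    try:
        return 'no_copy' in inspect.getfullargspec(dist.reduce_scatter).args
    except TypeError:
        # Some builds wrap reduce_scatter in a C callable with no introspectable signature.
        return False

# Usage sketch: fall back to the copying path on stock PyTorch.
# extra = {'no_copy': True} if reduce_scatter_supports_no_copy() else {}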
@@ -267,48 +265,13 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._total_param_size = p_offset
         dwu_min_page_size = 256 * self._num_blocks * self._num_chunks * self._group_size
         self._total_param_size = ((self._total_param_size + dwu_min_page_size - 1) // dwu_min_page_size) * dwu_min_page_size
-        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
-    def _lazy_init_stage1(self):
-        if self._lazy_init_stage1_done: return
-        p_i = 0
-        #self._model_params = []
-        #self._grad_accs = []
-        #self._group_properties = []
-        for group in self.param_groups:
-            for p in group['params']:
-                torch.distributed.broadcast(p, 0)
-                if not p.requires_grad:
-                    continue
-                def wrapper(param, param_i):
-                    param_tmp = param.expand_as(param)
-                    grad_acc = param_tmp.grad_fn.next_functions[0][0]
-                    def allreduce_hook(*unused):
-                        if not self._set_flat_param_view:
-                            if self._first_step:
-                                # first time
-                                self._param_order.add(param_i)
-                            else:
-                                idx = self._param_order.order.index(param_i)
-                                self._do_overlapped_reduction(idx, param)
-                        else:
-                            if not self._first_step:
-                                idx = self._param_order.order.index(param_i)
-                                self._do_overlapped_reduction(idx, param)
-                    grad_acc.register_hook(allreduce_hook)
-                    self._grad_accs.append(grad_acc)
-                wrapper(p, p_i)
-                p_i += 1
         self._block_size = self._total_param_size // self._num_blocks
         self._chunk_size = self._block_size // self._num_chunks
         self._shard_size = self._chunk_size // self._group_size
         #print("self._net_total_param_size=%d, self._total_param_size=%d, dwu_min_page_size=%d, self._block_size=%d, self._chunk_size=%d, self._shard_size=%d" % (self._net_total_param_size, self._total_param_size,dwu_min_page_size,self._block_size,self._chunk_size,self._shard_size))
         self._flat_grads = torch.zeros([self._total_param_size], dtype=torch.float16, device='cuda')
+        self._new_params = torch.zeros([self._total_param_size], dtype=torch.uint8 if self._e5m2_allgather else torch.float16, device='cuda')
         self._mega_shard_size = self._num_blocks * self._num_chunks * self._shard_size
         # initialize master weights, moments buffers if not loaded from checkpoint
         if self._fp32_p is None:
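For reference, the sizing above pads the flattened parameter space to a multiple of 256 * num_blocks * num_chunks * group_size so that blocks, chunks and per-rank shards all divide evenly. A small sketch of the same arithmetic with illustrative numbers (the function and the values are only an example, not the optimizer's API):

def dwu_buffer_sizes(total_param_size, num_blocks=4, num_chunks=4, group_size=8):
    """Mirror the padding and split arithmetic used for the flat buffers."""
    min_page_size = 256 * num_blocks * num_chunks * group_size
    # round the parameter count up to a whole number of "pages"
    padded = ((total_param_size + min_page_size - 1) // min_page_size) * min_page_size
    block_size = padded // num_blocks
    chunk_size = block_size // num_chunks
    shard_size = chunk_size // group_size
    return padded, block_size, chunk_size, shard_size

# e.g. 110M parameters, 4 blocks x 4 chunks, 8 ranks per group:
print(dwu_buffer_sizes(110_000_000))  # (110002176, 27500544, 6875136, 859392)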
@@ -327,18 +290,11 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                 return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
             def __shardify(p):
                 return [p[shard_id*self._shard_size:(shard_id+1)*self._shard_size] for shard_id in range(self._group_size)]
-            list_of_blocks = __blockify(p)
+            list_of_blocks = __blockify(self._flat_grads)
             list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
             list_of_list_of_list_of_shards = [[__shardify(chunk) for chunk in chunks] for chunks in list_of_list_of_chunks]
             return list_of_blocks, list_of_list_of_chunks, list_of_list_of_list_of_shards
-        def _flat_split_no_shards(p):
-            def __blockify(p):
-                return [p[block_id*self._block_size:(block_id+1)*self._block_size] for block_id in range(self._num_blocks)]
-            def __chunkify(p):
-                return [p[chunk_id*self._chunk_size:(chunk_id+1)*self._chunk_size] for chunk_id in range(self._num_chunks)]
-            list_of_blocks = __blockify(self._flat_grads)
-            list_of_list_of_chunks = [__chunkify(block) for block in list_of_blocks]
-            return list_of_blocks, list_of_list_of_chunks
+        self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
         def _full_packed_split(p):
             def __shardify(p):
                 return [p[mega_shard*self._mega_shard_size:(mega_shard+1)*self._mega_shard_size] for mega_shard in range(self._group_size)]
@@ -350,6 +306,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             list_of_list_of_mega_blocks = [__blockify(mega_shard) for mega_shard in list_of_mega_shards]
             list_of_list_of_list_of_mega_chunks = [[__chunkify(mega_block) for mega_block in mega_blocks] for mega_blocks in list_of_list_of_mega_blocks]
             return list_of_mega_shards, list_of_list_of_mega_blocks, list_of_list_of_list_of_mega_chunks
+        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
         def _packed_split(p):
             def __packed_blockify(p):
                 packed_block_size = self._num_chunks*self._shard_size
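The `_flat_split`, `_full_packed_split` and `_packed_split` helpers above are nested slicing of one flat buffer into [block][chunk][shard] views; nothing is copied. A condensed sketch of the same idea on a plain tensor (illustrative sizes, not the optimizer's internals):

import torch

def flat_split(flat, num_blocks, num_chunks, group_size):
    """Split a flat 1-D tensor into [block][chunk][shard] views without copying."""
    block_size = flat.numel() // num_blocks
    chunk_size = block_size // num_chunks
    shard_size = chunk_size // group_size
    blocks = [flat[b * block_size:(b + 1) * block_size] for b in range(num_blocks)]
    chunks = [[blk[c * chunk_size:(c + 1) * chunk_size] for c in range(num_chunks)] for blk in blocks]
    shards = [[[ch[s * shard_size:(s + 1) * shard_size] for s in range(group_size)] for ch in row] for row in chunks]
    return blocks, chunks, shards

flat = torch.arange(128, dtype=torch.float32)  # 2 blocks x 2 chunks x 4 ranks x 8 elements
blocks, chunks, shards = flat_split(flat, num_blocks=2, num_chunks=2, group_size=4)
assert shards[1][0][3].numel() == 8
assert torch.equal(shards[1][0][3], flat[88:96])  # still a view into the same storage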
@@ -360,89 +317,15 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             list_of_blocks = __packed_blockify(p)
             list_of_list_of_chunks = [__packed_chunkify(block) for block in list_of_blocks]
             return list_of_blocks, list_of_list_of_chunks
-        def _split_assign(shards):
-            packed_block_size = self._num_chunks*self._shard_size
-            list_of_list_of_chunks=[]
-            for block_id in range(self._num_blocks):
-                list_of_chunks=[]
-                for chunk_id in range(self._num_chunks):
-                    #self._fp16_g[block_id*packed_block_size+chunk_id*self._shard_size:block_id*packed_block_size+(chunk_id+1)*self._shard_size] = shards[block_id][chunk_id][self._rank_in_group]
-                    list_of_chunks.append( shards[block_id][chunk_id][self._rank_in_group])
-                list_of_list_of_chunks.append(list_of_chunks)
-            return list_of_list_of_chunks
-        self._new_params_mega_shards, self._new_params_mega_blocks, self._new_params_mega_chunks = _full_packed_split(self._new_params)
-        # this splitting scheme is needed when allgather needs to be split into multiple chunks in a contiguous way
-        self._new_params2_blocks, self._new_params2_chunks, self._new_params2_shards = _flat_split(self._new_params)
         self._fp32_p_blocks, self._fp32_p_chunks = _packed_split(self._fp32_p)
         self._fp32_m_blocks, self._fp32_m_chunks = _packed_split(self._fp32_m)
         self._fp32_v_blocks, self._fp32_v_chunks = _packed_split(self._fp32_v)
         self._fp32_u_blocks, self._fp32_u_chunks = _packed_split(self._fp32_u)
         self._fp16_p_blocks, self._fp16_p_chunks = _packed_split(self._fp16_p)
-        if self._full_ar:
-            # for gradient all-reduce
-            self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
-            # for weight update
-            self._fp16_g_chunks = _split_assign(self._flat_grads_shards)
-        else:
-            self._flat_grads_blocks, self._flat_grads_chunks, self._flat_grads_shards = _flat_split(self._flat_grads)
         self._fp16_g_blocks, self._fp16_g_chunks = _packed_split(self._fp16_g)
         self._lazy_init_stage1_done = True
-    def _lazy_init_stage2(self):
-        if self._lazy_init_stage2_done: return
-        if not self._set_flat_param_view:
-            # reversing is needed for overlapping allreduce and backprop, but currently not supported for flat param view
-            self._param_order.order.reverse()
-        # re-order model_params, grad_accs, group_properties lists
-        self._model_params = [self._model_params[i] for i in self._param_order.order]
-        self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
-        self._group_properties = [self._group_properties[i] for i in self._param_order.order]
-        def _get_flat_view(param):
-            if param.is_contiguous(memory_format=torch.channels_last):
-                K, C, H, W = param.shape
-                pv = param.as_strided(size=(K,H,W,C), stride=(H*W*C, W*C, C, 1))
-            elif param.is_contiguous(memory_format=torch.channels_last_3d):
-                K, C, D, H, W = param.shape
-                pv = param.as_strided(size=(K,D,H,W,C), stride=(D*H*W*C, H*W*C, W*C, C, 1))
-            else:
-                pv = param
-            return pv.view(-1)
-        # re-collect grads info (size, offset) after ordering
-        prev = None
-        p_offset = 0
-        self._grads_info = []
-        self._individual_flat_grads = []
-        for i, p in enumerate(self._model_params):
-            p_grads_size = p.numel()
-            self._grads_info.append({"param_grads_size":p_grads_size, "param_offset":p_offset})
-            self._individual_flat_grads.append(self._flat_grads[p_offset:p_offset+p_grads_size].view_as(p))
-            # for the first iteration
-            self._do_overlapped_reduction(i, p)
-            p_offset += p_grads_size
-            # Only enforce 128b alignment (64 * fp16) for non-consecutive parameters
-            # RNN is one example of consecutive parameters:
-            # (weight_ih, weight_hh, bias_ih, bias_hh)
-            if prev is not None and (prev.data_ptr() + prev.numel() * prev.element_size() != p.data_ptr()):
-                p_offset = ((p_offset + 63) // 64) * 64
-            prev = p
-        self._low_param_i = [0]*self._num_blocks
-        for block_id in range(self._num_blocks-1,-1,-1):
-            p_i = len(self._grads_info)-1
-            while p_i > 0 and self._grads_info[p_i]["param_offset"] > block_id*self._block_size:
-                p_i -= 1
-            self._low_param_i[block_id] = p_i
-        #print("self._low_param_i", self._low_param_i)
-        self._lazy_init_stage1_done = True
     def _lazy_init_stage2(self):
         if self._lazy_init_stage2_done: return
@@ -508,8 +391,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     grad_offset = clipped_start - flat_grad_start
                     grad_length = clipped_end - clipped_start
                     shard_offset = clipped_start - flat_shard_start
-                    pf = _get_flat_view(p)
-                    model_param_fragment = pf[grad_offset:grad_offset+grad_length]
+                    model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
                     new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
                     if model_param_fragment.dtype == torch.float16:
                         self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) )
@@ -588,7 +470,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         return flush_block
     def _pipeline_block_reductions(self, block_id):
-        if self._clip_after_ar:
         self._flatten_grad_mt(1.0/self._world_size)
         # Reduction within each node
@@ -600,7 +481,7 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
             rs_stream.wait_stream(torch.cuda.current_stream())
             with torch.cuda.stream(rs_stream):
-                works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True, no_copy=False)
+                works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True,no_copy=True)
         # Reduction across nodes for each rank
         if self._num_groups > 1:
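The hunk above keeps the two-level reduction scheme: an intra-node reduce-scatter turns each chunk of the flat gradient into one shard per rank, and, when there is more than one node group, an all-reduce across the matching ranks of the other groups completes the global sum. A minimal sketch of that pattern with hypothetical process-group handles (`rs_group`, `ar_group`) standing in for the optimizer's `_rs_pg` / `_ar_pg`:

import torch.distributed as dist

def two_level_reduce(chunk_shards, out_shard, rs_group, ar_group, num_groups):
    """chunk_shards: per-rank slices of one gradient chunk; out_shard: this rank's result."""
    # Stage 1: intra-group reduce-scatter, each rank receives the sum of its own shard.
    work = dist.reduce_scatter(out_shard, chunk_shards, group=rs_group, async_op=True)
    if num_groups > 1:
        # Stage 2: inter-group all-reduce over the same shard index on every group.
        work.wait()
        work = dist.all_reduce(out_shard, group=ar_group, async_op=True)
    return work  # caller waits on this before using the reduced shard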
@@ -623,52 +504,6 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
             l2_grad_norm_sq = self._fp16_g.norm(dtype=torch.float32, p=2)**2
             torch.distributed.all_reduce(l2_grad_norm_sq, group=self._l2_grad_norm_pg)
             self._L2_grad_norm = l2_grad_norm_sq.sqrt()
-        else:
-            # Copy model grads to flat grads buffer
-            self._flatten_grad_mt(1.0)
-            # Compute L2 grad norm
-            self._l2_grad_norm_st.wait_stream(torch.cuda.current_stream())
-            with torch.cuda.stream(self._l2_grad_norm_st):
-                self._L2_grad_norm = self._flat_grads.norm(dtype=torch.float16, p=2).float()
-            torch.cuda.current_stream().wait_stream(self._l2_grad_norm_st)
-            # Apply clipping & pre-reduction scaling on grads
-            loss_scale = self.global_scale
-            max_grad_norm = loss_scale*self.defaults['max_grad_norm']
-            coeff = max_grad_norm /(1e-6+self.L2_grad_norm)
-            coeff = (coeff>1) * self._one + (coeff<=1) * coeff
-            tmp = torch.cat(((self._one), (coeff)))
-            index = (coeff+1>coeff).int()
-            scale = tmp.index_select(0, index).half()/self._world_size
-            self._flat_grads.mul_(scale)
-            # Reduction within each node
-            # Changes gradient format from [block * chunk * shard] to [shard * block * chunk]
-            # The output format is the same as the fp32 master parameters
-            works = [None]*self._num_chunks
-            for chunk_id in range(self._num_chunks):
-                glob_chunk_id = block_id * self._num_chunks + chunk_id
-                rs_stream = self._rs_st[glob_chunk_id%self._num_rs_pg]
-                rs_stream.wait_stream(torch.cuda.current_stream())
-                rs_stream.wait_stream(self._l2_grad_norm_st)
-                with torch.cuda.stream(rs_stream):
-                    works[chunk_id] = torch.distributed.reduce_scatter(self._fp16_g_chunks[block_id][chunk_id],self._flat_grads_shards[block_id][chunk_id],group=self._rs_pg[glob_chunk_id%self._num_rs_pg],async_op=True, no_copy=False)
-            # Reduction across nodes for each rank
-            if self._num_groups > 1:
-                for chunk_id in range(self._num_chunks):
-                    glob_chunk_id = block_id * self._num_chunks + chunk_id
-                    ar_stream = self._ar_st[glob_chunk_id%self._num_ar_pg]
-                    with torch.cuda.stream(ar_stream):
-                        works[chunk_id].wait()
-                        works[chunk_id] = torch.distributed.all_reduce(self._fp16_g_chunks[block_id][chunk_id],group=self._ar_pg[glob_chunk_id%self._num_ar_pg],async_op=True)
-            self._reductions_works[block_id] = works
-            if block_id == 0:
-                for block_id in range(self._num_blocks):
-                    for chunk_id in range(self._num_chunks):
-                        self._reductions_works[block_id][chunk_id].wait()
     def __compute_contrib_param_norm(self):
         if self._contrib_model_param_for_norm_fp16 is not None and self._contrib_model_param_for_norm_fp32 is not None:
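For context, the removed pre-reduction path above folds gradient clipping into the scale applied to the flat FP16 gradients: the factor is min(1, max_norm / grad_norm), divided by the world size so that the following summation averages. A scalar sketch of that coefficient (plain Python, ignoring the NaN guard the tensor code builds with index_select):

def pre_reduction_scale(grad_norm, max_grad_norm, loss_scale, world_size):
    """Scale multiplied into the flat grads before reduce-scatter."""
    max_norm = loss_scale * max_grad_norm      # grads still carry the loss scale
    coeff = max_norm / (1e-6 + grad_norm)      # > 1 means no clipping is needed
    clip = min(1.0, coeff)
    return clip / world_size                   # pre-divide so the reduction averages

# e.g. grad_norm 5.0, max_grad_norm 1.0, loss_scale 1.0, 8 ranks -> ~0.025
print(pre_reduction_scale(5.0, 1.0, 1.0, 8))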
@@ -693,20 +528,12 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
     def _pipeline_step(self):
         global_scale = self.global_scale
-        # if clip before ar, set max_grad_norm to 0
-        max_grad_norm = self.defaults['max_grad_norm'] * self._clip_after_ar
-        self._completion_st.wait_stream(self._l2_grad_norm_st)
+        max_grad_norm = self.defaults['max_grad_norm']
         global_grad_norm = self.L2_grad_norm
         # check global_grad_norm and fill overflow_buf
         is_finite = (global_grad_norm + 1 > global_grad_norm).int()
         self._overflow_buf = self._one * (is_finite ^ self._one) # toggle between 0 and 1
-        torch.distributed.all_reduce(is_finite,
-                                     op=torch.distributed.ReduceOp.MIN,
-                                     group=self._current_process_group)
-        torch.distributed.all_reduce(self._overflow_buf,
-                                     op=torch.distributed.ReduceOp.MAX,
-                                     group=self._current_process_group)
         # increment step counter if no overflow
         self._step += is_finite
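The `is_finite` check above relies on `x + 1 > x` being false for Inf and NaN, which lets the optimizer detect an overflowed gradient norm without synchronizing to the host. A small illustration:

import torch

def is_finite_flag(global_grad_norm):
    """1 for an ordinary value, 0 for Inf/NaN; stays on the tensor's device."""
    return (global_grad_norm + 1 > global_grad_norm).int()

print(is_finite_flag(torch.tensor(3.5)))            # tensor(1, dtype=torch.int32)
print(is_finite_flag(torch.tensor(float('inf'))))   # tensor(0, dtype=torch.int32)
print(is_finite_flag(torch.tensor(float('nan'))))   # tensor(0, dtype=torch.int32)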
@@ -745,39 +572,24 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     self._contrib_weight_decay,
                     global_grad_norm,
                     self._use_nvlamb)
-        torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0],no_copy=False)
+        torch.distributed.all_gather(self._new_params_mega_shards, self._fp16_p, group=self._ag_pg[0], no_copy=True)
     def _flatten_grad_mt(self, scale):
         if len(self._grads_fp16) > 0:
             self._overflow_buf.zero_()
-            if not self._fused_norm:
             multi_tensor_applier(
                     amp_C.multi_tensor_scale,
                     self._overflow_buf,
                     list(zip(*self._grads_fp16)),
                     scale)
-            else:
-                self._L2_grad_norm=multi_tensor_applier(
-                        amp_C.multi_tensor_l2norm_scale,
-                        self._overflow_buf,
-                        list(zip(*self._grads_fp16)),
-                        scale, False)[0].float()
             self._grads_fp16 = []
         if len(self._grads_fp32) > 0:
             self._overflow_buf.zero_()
-            if not self._fused_norm:
             multi_tensor_applier(
                     amp_C.multi_tensor_scale,
                     self._overflow_buf,
                     list(zip(*self._grads_fp32)),
                     scale)
-            else:
-                self._L2_grad_norm=multi_tensor_applier(
-                        amp_C.multi_tensor_l2norm_scale,
-                        self._overflow_buf,
-                        list(zip(*self._grads_fp32)),
-                        scale, False)[0].float()
             self._grads_fp32 = []
     def _do_overlapped_reduction(self, param_i, param):
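`_flatten_grad_mt` batches the copy-and-scale of many small gradients into the preallocated flat buffer through apex's `multi_tensor_applier` with the `amp_C.multi_tensor_scale` kernel; roughly, the first tensor list is read, the second is written, and `_overflow_buf` is set non-zero when a non-finite value is encountered. A rough plain-PyTorch equivalent, shown only to make the data movement explicit (one kernel launch per tensor, so much slower than the fused path):

import torch

def flatten_and_scale(pairs, overflow_buf, scale):
    """pairs: list of (src_grad, dst_flat_slice) tuples."""
    for src, dst in pairs:
        scaled = src.float() * scale
        if not torch.isfinite(scaled).all():
            overflow_buf.fill_(1)     # mirror the overflow flag
        dst.copy_(scaled)             # dst is a slice of the flat FP16 gradient buffer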
...
@@ -510,7 +510,7 @@ if "--fast_multihead_attn" in sys.argv or "--cuda_ext" in sys.argv:
         cc_flag.append('-gencode')
         cc_flag.append('arch=compute_86,code=sm_86')
-    subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
+    #subprocess.run(["git", "submodule", "update", "--init", "apex/contrib/csrc/multihead_attn/cutlass"])
     nvcc_args_mha = ['-O3',
                      '-gencode',
                      'arch=compute_70,code=sm_70',
...