"docs/source/vscode:/vscode.git/clone" did not exist on "91b05e2ec78e44856d90f4258f91d56807227bac"
Commit 333da806 authored by Thorsten Kurth

Wrote a small wrapper function for flat view creation in _lazy_init_stage2 to support channels last data formats
parent d6b5ae5d
@@ -332,6 +332,17 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
         self._param_order.order.reverse()

+        def _get_flat_view(param):
+            if param.is_contiguous(memory_format=torch.channels_last):
+                K, C, H, W = param.shape
+                pv = param.as_strided(size=(K,H,W,C), stride=(H*W*C, W*C, C, 1))
+            elif param.is_contiguous(memory_format=torch.channels_last_3d):
+                K, C, D, H, W = param.shape
+                pv = param.as_strided(size=(K,D,H,W,C), stride=(D*H*W*C, H*W*C, W*C, C, 1))
+            else:
+                pv = param
+            return pv.view(-1)
+
         # re-order model_params, grad_accs, group_properties lists
         self._model_params = [self._model_params[i] for i in self._param_order.order]
         self._grad_accs = [self._grad_accs[i] for i in self._param_order.order]
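For context (an editor's sketch, not part of the commit): a tensor in channels_last memory format stores its elements in NHWC order, so the plain p.view(-1) used before this change raises a RuntimeError, because the logical NCHW order no longer matches the storage order. _get_flat_view instead builds an NHWC-shaped strided view that is contiguous over the underlying storage and can therefore be flattened without a copy. A minimal check of that behavior:

    import torch

    # Sketch only; assumes the _get_flat_view helper from the diff above is in scope.
    p = torch.randn(8, 4, 5, 5).to(memory_format=torch.channels_last)

    # Pre-commit path: fails, since no flat view exists over the logical NCHW shape.
    try:
        p.view(-1)
    except RuntimeError:
        pass  # "view size is not compatible with input tensor's size and stride"

    flat = _get_flat_view(p)
    assert flat.data_ptr() == p.data_ptr()  # a true view: shares storage, no copy
    # elements appear in physical (NHWC) storage order
    assert torch.equal(flat, p.permute(0, 2, 3, 1).reshape(-1))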
@@ -392,7 +403,8 @@ class DistributedFusedLAMB(torch.optim.Optimizer):
                     grad_offset = clipped_start - flat_grad_start
                     grad_length = clipped_end - clipped_start
                     shard_offset = clipped_start - flat_shard_start
-                    model_param_fragment = p.view(-1)[grad_offset:grad_offset+grad_length]
+                    pf = _get_flat_view(p)
+                    model_param_fragment = pf[grad_offset:grad_offset+grad_length]
                     new_param_packed_fragment = self._new_params_mega_chunks[shard_id][block_id][chunk_id][shard_offset:shard_offset+grad_length]
                     if model_param_fragment.dtype == torch.float16:
                         self._packed_flat_to_model_params_fp16.append( (new_param_packed_fragment, model_param_fragment) )
...
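The second hunk matters because the packed new-parameter chunks are laid out in flat storage order; slicing the model parameter through _get_flat_view keeps both sides of each (packed fragment, model fragment) pair in the same element order. A hypothetical illustration (the offsets and the new_params buffer are invented for the example):

    import torch

    # Hypothetical fragment copy in the style of the optimizer's unpacking step;
    # assumes _get_flat_view from the diff above is in scope.
    p = torch.zeros(2, 3, 4, 4).to(memory_format=torch.channels_last)
    new_params = torch.arange(p.numel(), dtype=p.dtype)  # stands in for a packed chunk

    pf = _get_flat_view(p)            # physical-order flat view into p's storage
    pf[8:24].copy_(new_params[8:24])  # fragment copy, one (packed, model) pair

    # the write landed in p's storage in physical (NHWC) order
    assert torch.equal(p.permute(0, 2, 3, 1).reshape(-1)[8:24], new_params[8:24])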