"vscode:/vscode.git/clone" did not exist on "1abd5e781013a085f86586b30a248dc769909668"
Commit 66d27e9a authored by zms1999

fix bug: add proper comm group

parent 665b99bf
diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
index e8d0d02..369fdf6 100644
--- a/megatron/optimizer/clip_grads.py
+++ b/megatron/optimizer/clip_grads.py
@@ -52,6 +52,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
         total_norm = max(grad.abs().max() for grad in grads_for_norm)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
@@ -96,7 +101,19 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             # we need the pow(norm-type).
             total_norm = grad_norm ** norm_type
+                False # no per-parameter norm
+            )
+            grad_norm = grad_norm ** norm_type
+            torch.distributed.all_reduce(grad_norm,
+                                         group=mpu.get_model_parallel_group())
+            total_norm += grad_norm
+
         else:
             for grad in grads_for_norm:
                 grad_norm = torch.norm(grad, norm_type)
                 total_norm += grad_norm ** norm_type
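Context for the fix (editor's note, not part of the commit): clip_grad_norm_fp32 combines per-rank partial norms with torch.distributed.all_reduce. Without an explicit group argument, all_reduce runs over the default (world) process group, so the partial norm would also be summed across ranks that hold copies of the same gradients; passing mpu.get_model_parallel_group() restricts the reduction to the model-parallel ranks, matching the other reductions in this function. Below is a minimal sketch of the pattern, assuming an initialized torch.distributed process group; the helper name global_grad_norm and its arguments are hypothetical, not code from this repository.

    # Minimal sketch (not the repository's code): each rank holds a shard of
    # the gradients, contributes sum(|g|^p), and the sum is reduced over the
    # model-parallel group before taking the p-th root.
    import torch
    import torch.distributed as dist

    def global_grad_norm(local_grads, group, norm_type=2.0):
        # Per-rank partial sum of |g|^norm_type over the shards this rank owns.
        partial = torch.zeros(1, device=local_grads[0].device)
        for g in local_grads:
            partial += torch.norm(g, norm_type) ** norm_type
        # Reduce over the intended group only. Omitting `group` would reduce
        # over the default (world) group and mix in ranks that hold duplicate
        # copies of the same gradients, inflating the norm.
        dist.all_reduce(partial, op=dist.ReduceOp.SUM, group=group)
        return partial.item() ** (1.0 / norm_type)

In the patched file the same role is played by the added all_reduce call, with mpu.get_model_parallel_group() as the group argument.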