OpenDAS / FastMoE
"vscode:/vscode.git/clone" did not exist on "1abd5e781013a085f86586b30a248dc769909668"
Commit 66d27e9a, authored Nov 14, 2022 by zms1999
fix bug: add proper comm group
Parent: 665b99bf
Showing 1 changed file with 4 additions and 4 deletions.
examples/megatron/clip-grad-v2.2.patch
 diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
-index e8d0d02..fd6660a 100644
+index e8d0d02..369fdf6 100644
 --- a/megatron/optimizer/clip_grads.py
 +++ b/megatron/optimizer/clip_grads.py
 @@ -52,6 +52,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
...
@@ -30,7 +30,7 @@ index e8d0d02..fd6660a 100644
      total_norm = max(grad.abs().max() for grad in grads_for_norm)
      total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
      # Take max across all model-parallel GPUs.
-@@ -96,7 +101,18 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
+@@ -96,7 +101,19 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
      # we need the pow(norm-type).
      total_norm = grad_norm ** norm_type
...
@@ -41,7 +41,8 @@ index e8d0d02..fd6660a 100644
 +            False # no per-parameter norm
 +        )
 +        grad_norm = grad_norm ** norm_type
-+        torch.distributed.all_reduce(grad_norm)
++        torch.distributed.all_reduce(grad_norm,
++                                     group=mpu.get_model_parallel_group())
 +        total_norm += grad_norm
 +
      else:
...
@@ -49,4 +50,3 @@ index e8d0d02..fd6660a 100644
      for grad in grads_for_norm:
          grad_norm = torch.norm(grad, norm_type)
          total_norm += grad_norm ** norm_type
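Note on the fix: the line being corrected in the shipped patch sums a squared gradient norm across ranks with torch.distributed.all_reduce but omits the group argument, so the reduction runs over the default (world) group and also counts data-parallel replicas that hold identical copies of these gradients. The commit passes group=mpu.get_model_parallel_group() so the sum covers only the model-parallel ranks. Below is a minimal, self-contained sketch of that pattern, not the FastMoE/Megatron code itself; the function name global_grad_norm and its arguments are illustrative.

import torch
import torch.distributed as dist


def global_grad_norm(parameters, norm_type=2.0, group=None):
    """Sketch: combine per-rank gradient norms into one global norm.

    `group` should contain exactly the ranks that hold disjoint shards
    of the same set of parameters (e.g. the model-parallel group).
    Reducing over the default world group instead would also sum the
    contributions of data-parallel replicas, whose gradients are
    identical, and overestimate the norm.
    """
    grads = [p.grad.detach() for p in parameters if p.grad is not None]
    device = grads[0].device if grads else torch.device("cpu")

    # Local contribution: sum of ||g||^norm_type over this rank's shard.
    local_sum = torch.zeros(1, device=device)
    for g in grads:
        local_sum += torch.norm(g, norm_type) ** norm_type

    # Sum the per-rank contributions across the chosen group only.
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(local_sum, op=dist.ReduceOp.SUM, group=group)

    return local_sum.item() ** (1.0 / norm_type)

In the patched clip_grad_norm_fp32, the group comes from Megatron's mpu.get_model_parallel_group(), which is what the new second line of the all_reduce call supplies.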