OpenDAS / FastMoE · Commits

Unverified commit cdc140f1
Authored Nov 14, 2022 by Rick Ho; committed by GitHub on Nov 14, 2022
Merge pull request #138 from laekov/clip-bug

fix bug: add proper comm group

Parents: 53844b65, 66d27e9a
Changes: 1 changed file with 4 additions and 4 deletions.

examples/megatron/clip-grad-v2.2.patch (+4, -4) @ cdc140f1
 diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
-index e8d0d02..fd6660a 100644
+index e8d0d02..369fdf6 100644
 --- a/megatron/optimizer/clip_grads.py
 +++ b/megatron/optimizer/clip_grads.py
 @@ -52,6 +52,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
...
@@ -30,7 +30,7 @@ index e8d0d02..fd6660a 100644
          total_norm = max(grad.abs().max() for grad in grads_for_norm)
          total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
          # Take max across all model-parallel GPUs.
-@@ -96,7 +101,18 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
+@@ -96,7 +101,19 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
              # we need the pow(norm-type).
              total_norm = grad_norm ** norm_type
...
@@ -41,7 +41,8 @@ index e8d0d02..fd6660a 100644
 +                False # no per-parameter norm
 +            )
 +            grad_norm = grad_norm ** norm_type
-+            torch.distributed.all_reduce(grad_norm)
++            torch.distributed.all_reduce(grad_norm,
++                    group=mpu.get_model_parallel_group())
 +            total_norm += grad_norm
 +
          else:
...
@@ -49,4 +50,3 @@ index e8d0d02..fd6660a 100644
              for grad in grads_for_norm:
                  grad_norm = torch.norm(grad, norm_type)
                  total_norm += grad_norm ** norm_type
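
For context only (this note and the sketch below are not part of the commit): the line changed by the patch sums per-shard gradient-norm contributions with torch.distributed.all_reduce, and the fix passes the model-parallel group explicitly so the reduction spans only the ranks that shard one model replica instead of the default world group. A minimal sketch of that pattern follows; the function name reduce_sharded_grad_norm and the model_parallel_group handle are assumptions made for this example, not FastMoE or Megatron-LM APIs.

# Illustrative sketch only, not from FastMoE or Megatron-LM.
# It shows how the `group=` argument controls which ranks take part in the
# reduction of a sharded gradient norm. `model_parallel_group` is assumed to
# have been created elsewhere, e.g. with torch.distributed.new_group(...).
import torch
import torch.distributed as dist


def reduce_sharded_grad_norm(local_grad_norm: torch.Tensor,
                             norm_type: float = 2.0,
                             model_parallel_group=None) -> torch.Tensor:
    """Combine per-shard gradient norms into the norm of the full gradient.

    Each rank contributes ||g_shard|| ** norm_type; summing those terms over
    the ranks holding different shards of the same model replica and taking
    the norm_type-th root gives the global gradient norm.
    """
    total = local_grad_norm.clone() ** norm_type
    # Passing an explicit group restricts the sum to the model-parallel ranks.
    # With group=None the reduction runs over the default (world) group, so
    # data-parallel replicas, which hold identical gradients, would also be
    # summed in and inflate the result.
    dist.all_reduce(total, op=dist.ReduceOp.SUM, group=model_parallel_group)
    return total ** (1.0 / norm_type)

The actual patch keeps Megatron's own accumulation (total_norm += grad_norm) and only changes the all_reduce call to use mpu.get_model_parallel_group(), as shown in the diff above.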