OpenDAS / FastMoE

Commit bc4d35b8, authored Jul 30, 2021 by Rick Ho
Parent: bba5f289

add patch for v2.2
Showing 2 changed files with 59 additions and 0 deletions:

- examples/megatron/README.md (+9, -0)
- examples/megatron/clip-grad-v2.2.patch (+50, -0)
examples/megatron/README.md
@@ -44,6 +44,15 @@ model parallel model group.

```python
from fmoe.megatron import DistributedDataParallel as LocalDDP
```
### Fix gradient clipping

Megatron-LM normalizes gradients by their global norm, which is incompatible with FastMoE: an incorrect gradient norm leads to inconsistent parameter updates. Apply `clip-grad-v2.2.patch` to fix the issue.

Note that only the 2-norm is implemented in the patch. If another norm type is used, remember to implement it accordingly.
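To see what the patch computes, here is a minimal sketch of the norm combination, generalized from the 2-norm to any finite p-norm. It is not part of the patch: the function name is made up, apex's fused `multi_tensor_l2norm` kernel is replaced by plain `torch.norm`, the reduction Megatron-LM later performs across the model-parallel group is omitted, and it assumes `torch.distributed` is already initialized with the gradients on GPU.

```python
import torch
import torch.distributed as dist


def combined_grad_norm(grads_for_norm, grads_in_moe, norm_type=2.0):
    """Global p-norm over replicated gradients plus per-rank expert gradients.

    grads_for_norm: gradients that are identical on every data-parallel rank.
    grads_in_moe:   expert gradients that exist only on this rank.
    """
    norm_type = float(norm_type)

    # Replicated gradients: every rank already holds the same values, so the
    # sum of |g|^p can be taken locally.
    total = sum(torch.norm(g, norm_type) ** norm_type for g in grads_for_norm)

    # Expert gradients: each rank owns different experts, so sum |g|^p locally
    # and then accumulate the partial sums from all ranks.
    device = grads_in_moe[0].device if grads_in_moe else torch.device('cuda')
    moe_total = torch.zeros(1, device=device)
    for g in grads_in_moe:
        moe_total += torch.norm(g, norm_type) ** norm_type
    dist.all_reduce(moe_total)  # default op is SUM over the default group

    return ((total + moe_total) ** (1.0 / norm_type)).item()
```

The key point is that expert gradients differ across data-parallel ranks, so their contribution has to be summed over all ranks before the p-th root is taken; any other norm type added to the patch needs the same treatment.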
### Train as usual

Start training with FastMoE by using the scripts provided by Megatron-LM.
examples/megatron/clip-grad-v2.2.patch (new file, mode 100644)
diff --git a/megatron/optimizer/clip_grads.py b/megatron/optimizer/clip_grads.py
index e8d0d02..df898f9 100644
--- a/megatron/optimizer/clip_grads.py
+++ b/megatron/optimizer/clip_grads.py
@@ -52,6 +52,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
     #   - should not be a replica due to tensor model parallelism
     grads = []
     grads_for_norm = []
+    grads_in_moe = []
     for param in parameters:
         grad_not_none = param.grad is not None
         is_not_shared = not hasattr(param, 'shared') or not param.shared
@@ -63,7 +64,10 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             assert param.grad.type() == 'torch.cuda.FloatTensor'
             grads.append(grad)
         if grad_not_none and is_not_shared and is_not_tp_duplicate:
-            grads_for_norm.append(grad)
+            if hasattr(param, 'dp_comm') and param.dp_comm in ('none'):
+                grads_in_moe.append(grad)
+            else:
+                grads_for_norm.append(grad)
 
     # Norm parameters.
     max_norm = float(max_norm)
@@ -72,6 +76,7 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
 
     # Calculate norm.
     if norm_type == inf:
+        # TODO: moe
         total_norm = max(grad.abs().max() for grad in grads_for_norm)
         total_norm_cuda = torch.cuda.FloatTensor([float(total_norm)])
         # Take max across all model-parallel GPUs.
@@ -96,7 +101,17 @@ def clip_grad_norm_fp32(parameters, max_norm, norm_type=2):
             # we need the pow(norm-type).
             total_norm = grad_norm ** norm_type
 
+            grad_norm, _ = multi_tensor_applier(
+                amp_C.multi_tensor_l2norm,
+                dummy_overflow_buf,
+                [grads_in_moe],
+                False # no per-parameter norm
+            )
+            torch.distributed.all_reduce(grad_norm)
+            total_norm += grad_norm ** norm_type
+
         else:
+            # TODO: moe
             for grad in grads_for_norm:
                 grad_norm = torch.norm(grad, norm_type)
                 total_norm += grad_norm ** norm_type
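For context, the `dp_comm` check in the patched loop is what routes a gradient into `grads_in_moe`: FastMoE's distributed wrapper tags expert-local parameters with a `dp_comm` attribute, and `'none'` appears to mark the parameters that are not synchronized by regular data parallelism. Below is a toy, single-process illustration of that split, with plain `nn.Linear` modules standing in for real expert layers.

```python
import torch
import torch.nn as nn

# Two stand-in modules: "dense" mimics an ordinary parameter, "expert" mimics
# an expert parameter that FastMoE would tag with dp_comm = 'none'.
dense = nn.Linear(4, 4)
expert = nn.Linear(4, 4)
for p in expert.parameters():
    p.dp_comm = 'none'  # the attribute the patched condition looks for

# Fake a backward pass so every parameter has a gradient.
x = torch.randn(2, 4)
(dense(x).sum() + expert(x).sum()).backward()

grads_for_norm, grads_in_moe = [], []
for param in list(dense.parameters()) + list(expert.parameters()):
    if getattr(param, 'dp_comm', None) == 'none':
        grads_in_moe.append(param.grad.detach())
    else:
        grads_for_norm.append(param.grad.detach())

print(len(grads_for_norm), len(grads_in_moe))  # -> 2 2 (weight and bias each)
```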