OpenDAS / Megatron-LM · Commits

Commit 57064fd6, authored Mar 10, 2020 by Mohammad Shoeybi, committed by Raul Puri on Mar 10, 2020
Parent: ca8dd4ac

memory optimization in mpu cross entropy
Showing 1 changed file with 13 additions and 11 deletions.

megatron/mpu/cross_entropy.py  (+13, -11)
@@ -27,21 +27,13 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
     @staticmethod
     def forward(ctx, vocab_parallel_logits, target):

-        # Copy so the input remains unchanged.
-        logits = vocab_parallel_logits.clone()
         # Maximum value along vocab dimension across all GPUs.
-        logits_max = torch.max(logits, dim=-1)[0]
+        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
         torch.distributed.all_reduce(logits_max,
                                      op=torch.distributed.ReduceOp.MAX,
                                      group=get_model_parallel_group())
         # Subtract the maximum value.
-        logits.sub_(logits_max.unsqueeze(dim=-1))
-        # Sum of exponential of logits along vocab dimension across all GPUs.
-        exp_logits = logits.exp()
-        sum_exp_logits = exp_logits.sum(dim=-1)
-        torch.distributed.all_reduce(sum_exp_logits,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=get_model_parallel_group())
+        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

         # Get the partition's vocab indecies
         get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
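A minimal single-GPU sketch (not taken from the repository) of what the first hunk changes: the max is now taken from, and subtracted into, the caller's vocab_parallel_logits shard itself rather than a clone() of it, so forward() no longer allocates a second logits-sized buffer. The tensor shape and the omission of the model-parallel all_reduce are assumptions for illustration.

import torch

# Hypothetical logits shard on one rank: [batch, seq, partition_vocab_size].
vocab_parallel_logits = torch.randn(4, 128, 1024)

# Before this commit (extra logits-sized buffer):
#     logits = vocab_parallel_logits.clone()
#     logits_max = torch.max(logits, dim=-1)[0]
# After this commit (no copy; the input shard is shifted in place):
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
# (In the real code, logits_max is all-reduced with ReduceOp.MAX here.)
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))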
@@ -59,11 +51,12 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         # Get predicted-logits = logits[target].
         # For Simplicity, we convert logits to a 2-D tensor with size
         # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
-        logits_2d = logits.view(-1, partition_vocab_size)
+        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
         masked_target_1d = masked_target.view(-1)
         arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
                                  device=logits_2d.device)
         predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
         predicted_logits = predicted_logits_1d.view_as(target)
         predicted_logits[target_mask] = 0.0
         # All reduce is needed to get the chunks from other GPUs.
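A small sketch (illustrative shapes, not from the repository) of the gather step in the second hunk: because the logits buffer is later overwritten in place by the exponentials, the gathered per-token target logits are copied into their own contiguous storage with clone().contiguous() before that happens.

import torch

# Hypothetical flattened shard: 3 token positions, partition vocab of 5.
logits_2d = torch.randn(3, 5)
masked_target_1d = torch.tensor([2, 0, 4])

arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
                         device=logits_2d.device)
# Gather each row's target logit, then give the result its own contiguous
# storage so later in-place writes to the logits buffer cannot affect it.
predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
predicted_logits_1d = predicted_logits_1d.clone().contiguous()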
@@ -71,6 +64,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
                                      op=torch.distributed.ReduceOp.SUM,
                                      group=get_model_parallel_group())

+        # Sum of exponential of logits along vocab dimension across all GPUs.
+        exp_logits = vocab_parallel_logits
+        torch.exp(vocab_parallel_logits, out=exp_logits)
+        sum_exp_logits = exp_logits.sum(dim=-1)
+        torch.distributed.all_reduce(sum_exp_logits,
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     group=get_model_parallel_group())
+
         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits
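Putting the three hunks together, a self-contained single-process sketch of the memory-optimized pattern: shift by the max in place, gather the target logits, then write exp() back into the same buffer via the out= argument so exp_logits merely aliases the input. The names mirror the diff, the model-parallel all_reduce calls are omitted, and stable_cross_entropy_inplace is a hypothetical helper, not part of the commit.

import torch
import torch.nn.functional as F

def stable_cross_entropy_inplace(vocab_parallel_logits, target):
    # Shift by the per-row max in place (no clone of the logits).
    logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
    vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))

    # Gather the target logits and copy them out before the buffer is reused.
    logits_2d = vocab_parallel_logits.view(-1, vocab_parallel_logits.size(-1))
    arange_1d = torch.arange(logits_2d.size(0), device=logits_2d.device)
    predicted_logits = logits_2d[arange_1d, target.view(-1)]
    predicted_logits = predicted_logits.clone().contiguous().view_as(target)

    # exp() written back into the same storage: exp_logits aliases the input.
    exp_logits = vocab_parallel_logits
    torch.exp(vocab_parallel_logits, out=exp_logits)
    sum_exp_logits = exp_logits.sum(dim=-1)

    # Loss = log(sum(exp(logits))) - predicted-logit.
    return torch.log(sum_exp_logits) - predicted_logits

# On one process this reproduces the reference cross entropy.
logits = torch.randn(6, 32)
target = torch.randint(0, 32, (6,))
reference = F.cross_entropy(logits, target, reduction='none')
print(torch.allclose(stable_cross_entropy_inplace(logits.clone(), target),
                     reference, atol=1e-6))

In the real kernel the max, the predicted logits, and the exponential sums are each additionally all-reduced across the model-parallel group, since every rank only holds its own slice of the vocabulary.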