OpenDAS / Megatron-LM, commit 57064fd6
Authored Mar 10, 2020 by Mohammad Shoeybi; committed by Raul Puri, Mar 10, 2020

Commit message:
memory optimization in mpu cross entropy
Parent: ca8dd4ac

Showing 1 changed file with 13 additions and 11 deletions.

megatron/mpu/cross_entropy.py (+13, -11):
@@ -27,21 +27,13 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
     @staticmethod
     def forward(ctx, vocab_parallel_logits, target):
 
-        # Copy so the input remains unchanged.
-        logits = vocab_parallel_logits.clone()
         # Maximum value along vocab dimension across all GPUs.
-        logits_max = torch.max(logits, dim=-1)[0]
+        logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
         torch.distributed.all_reduce(logits_max,
                                      op=torch.distributed.ReduceOp.MAX,
                                      group=get_model_parallel_group())
         # Subtract the maximum value.
-        logits.sub_(logits_max.unsqueeze(dim=-1))
-        # Sum of exponential of logits along vocab dimension across all GPUs.
-        exp_logits = logits.exp()
-        sum_exp_logits = exp_logits.sum(dim=-1)
-        torch.distributed.all_reduce(sum_exp_logits,
-                                     op=torch.distributed.ReduceOp.SUM,
-                                     group=get_model_parallel_group())
+        vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))
 
         # Get the partition's vocab indecies
         get_vocab_range = VocabUtility.vocab_range_from_per_partition_vocab_size
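
A hedged reading of this hunk: the forward pass no longer clones the input; the row-wise maximum and the max subtraction now run directly on vocab_parallel_logits, which removes one temporary of the full logits size at the cost of mutating the caller's tensor. A minimal single-process sketch of the same max-subtraction step, with hypothetical shapes and without the model-parallel all-reduce:

import torch

# Hypothetical stand-in for one vocab partition of the logits:
# [sequence, batch, partition-vocab-size].
vocab_parallel_logits = torch.randn(4, 8, 1024)

# Numerically stable softmax preparation, done in place as in the commit.
logits_max = torch.max(vocab_parallel_logits, dim=-1)[0]
vocab_parallel_logits.sub_(logits_max.unsqueeze(dim=-1))  # no clone()

# Every row's maximum is now 0, so a later exp() cannot overflow.
assert torch.allclose(vocab_parallel_logits.max(dim=-1)[0],
                      torch.zeros(4, 8))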
@@ -59,11 +51,12 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
         # Get predicted-logits = logits[target].
         # For Simplicity, we convert logits to a 2-D tensor with size
         # [*, partition-vocab-size] and target to a 1-D tensor of size [*].
-        logits_2d = logits.view(-1, partition_vocab_size)
+        logits_2d = vocab_parallel_logits.view(-1, partition_vocab_size)
         masked_target_1d = masked_target.view(-1)
         arange_1d = torch.arange(start=0, end=logits_2d.size()[0],
                                  device=logits_2d.device)
         predicted_logits_1d = logits_2d[arange_1d, masked_target_1d]
+        predicted_logits_1d = predicted_logits_1d.clone().contiguous()
         predicted_logits = predicted_logits_1d.view_as(target)
         predicted_logits[target_mask] = 0.0
         # All reduce is needed to get the chunks from other GPUs.
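
The added predicted_logits_1d = predicted_logits_1d.clone().contiguous() appears to guarantee that the gathered target logits live in their own contiguous storage before the input buffer is overwritten by the in-place exponential in the next hunk; advanced indexing normally returns a new tensor already, so this reads as a defensive copy. A small sketch of the rule it enforces, using hypothetical names and shapes:

import torch

buf = torch.randn(6, 5)                    # hypothetical flattened logits
rows = torch.arange(6)
cols = torch.randint(0, 5, (6,))

picked = buf[rows, cols]                   # gather the target logits
picked = picked.clone().contiguous()       # own storage, cannot alias buf

torch.exp(buf, out=buf)                    # buf is reused in place later on
# picked still holds the pre-exp values because it was copied out first.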
@@ -71,6 +64,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
                                      op=torch.distributed.ReduceOp.SUM,
                                      group=get_model_parallel_group())
+
+        # Sum of exponential of logits along vocab dimension across all GPUs.
+        exp_logits = vocab_parallel_logits
+        torch.exp(vocab_parallel_logits, out=exp_logits)
+        sum_exp_logits = exp_logits.sum(dim=-1)
+        torch.distributed.all_reduce(sum_exp_logits,
+                                     op=torch.distributed.ReduceOp.SUM,
+                                     group=get_model_parallel_group())
+
         # Loss = log(sum(exp(logits))) - predicted-logit.
         loss = torch.log(sum_exp_logits) - predicted_logits
...
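
The exponential/sum block moves below the predicted-logits gather, and exp_logits becomes just another name for vocab_parallel_logits: torch.exp(vocab_parallel_logits, out=exp_logits) writes the exponentials back into the input's own storage instead of allocating a fresh tensor of the same size. A short single-process sketch, with hypothetical shapes, showing that the two names share storage:

import torch

vocab_parallel_logits = torch.randn(4, 8, 1024)   # hypothetical partition

exp_logits = vocab_parallel_logits                # alias, not a copy
torch.exp(vocab_parallel_logits, out=exp_logits)  # exp computed in place

# Same underlying buffer, so no extra full-size allocation happened.
assert exp_logits.data_ptr() == vocab_parallel_logits.data_ptr()

sum_exp_logits = exp_logits.sum(dim=-1)           # reduce over the vocab dim

The trade-off visible across all three hunks is that the caller's vocab_parallel_logits is modified in place by the loss computation rather than left untouched.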