Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wxj
Megatron-LM
Commits
57064fd6
"vscode:/vscode.git/clone" did not exist on "cdf5a19b2f64020d44148619de86a520074385d6"
Commit
57064fd6
authored
Mar 10, 2020
by
Mohammad Shoeybi
Committed by
Raul Puri
Mar 10, 2020
Browse files
memory optimization in mpu cross entropy
parent
ca8dd4ac
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
11 deletions
+13
-11
megatron/mpu/cross_entropy.py
megatron/mpu/cross_entropy.py
+13
-11
No files found.
megatron/mpu/cross_entropy.py
View file @
57064fd6
...
@@ -27,21 +27,13 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
...
@@ -27,21 +27,13 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
@
staticmethod
@
staticmethod
def
forward
(
ctx
,
vocab_parallel_logits
,
target
):
def
forward
(
ctx
,
vocab_parallel_logits
,
target
):
# Copy so the input remains unchanged.
logits
=
vocab_parallel_logits
.
clone
()
# Maximum value along vocab dimension across all GPUs.
# Maximum value along vocab dimension across all GPUs.
logits_max
=
torch
.
max
(
logits
,
dim
=-
1
)[
0
]
logits_max
=
torch
.
max
(
vocab_parallel_
logits
,
dim
=-
1
)[
0
]
torch
.
distributed
.
all_reduce
(
logits_max
,
torch
.
distributed
.
all_reduce
(
logits_max
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
op
=
torch
.
distributed
.
ReduceOp
.
MAX
,
group
=
get_model_parallel_group
())
group
=
get_model_parallel_group
())
# Subtract the maximum value.
# Subtract the maximum value.
logits
.
sub_
(
logits_max
.
unsqueeze
(
dim
=-
1
))
vocab_parallel_logits
.
sub_
(
logits_max
.
unsqueeze
(
dim
=-
1
))
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits
=
logits
.
exp
()
sum_exp_logits
=
exp_logits
.
sum
(
dim
=-
1
)
torch
.
distributed
.
all_reduce
(
sum_exp_logits
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
group
=
get_model_parallel_group
())
# Get the partition's vocab indices
# Get the partition's vocab indices
get_vocab_range
=
VocabUtility
.
vocab_range_from_per_partition_vocab_size
get_vocab_range
=
VocabUtility
.
vocab_range_from_per_partition_vocab_size
...
@@ -59,11 +51,12 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
...
@@ -59,11 +51,12 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
# Get predicted-logits = logits[target].
# Get predicted-logits = logits[target].
# For simplicity, we convert logits to a 2-D tensor with size
# For simplicity, we convert logits to a 2-D tensor with size
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
# [*, partition-vocab-size] and target to a 1-D tensor of size [*].
logits_2d
=
logits
.
view
(
-
1
,
partition_vocab_size
)
logits_2d
=
vocab_parallel_
logits
.
view
(
-
1
,
partition_vocab_size
)
masked_target_1d
=
masked_target
.
view
(
-
1
)
masked_target_1d
=
masked_target
.
view
(
-
1
)
arange_1d
=
torch
.
arange
(
start
=
0
,
end
=
logits_2d
.
size
()[
0
],
arange_1d
=
torch
.
arange
(
start
=
0
,
end
=
logits_2d
.
size
()[
0
],
device
=
logits_2d
.
device
)
device
=
logits_2d
.
device
)
predicted_logits_1d
=
logits_2d
[
arange_1d
,
masked_target_1d
]
predicted_logits_1d
=
logits_2d
[
arange_1d
,
masked_target_1d
]
predicted_logits_1d
=
predicted_logits_1d
.
clone
().
contiguous
()
predicted_logits
=
predicted_logits_1d
.
view_as
(
target
)
predicted_logits
=
predicted_logits_1d
.
view_as
(
target
)
predicted_logits
[
target_mask
]
=
0.0
predicted_logits
[
target_mask
]
=
0.0
# All reduce is needed to get the chunks from other GPUs.
# All reduce is needed to get the chunks from other GPUs.
...
@@ -71,6 +64,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
...
@@ -71,6 +64,15 @@ class _VocabParallelCrossEntropy(torch.autograd.Function):
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
group
=
get_model_parallel_group
())
group
=
get_model_parallel_group
())
# Sum of exponential of logits along vocab dimension across all GPUs.
exp_logits
=
vocab_parallel_logits
torch
.
exp
(
vocab_parallel_logits
,
out
=
exp_logits
)
sum_exp_logits
=
exp_logits
.
sum
(
dim
=-
1
)
torch
.
distributed
.
all_reduce
(
sum_exp_logits
,
op
=
torch
.
distributed
.
ReduceOp
.
SUM
,
group
=
get_model_parallel_group
())
# Loss = log(sum(exp(logits))) - predicted-logit.
# Loss = log(sum(exp(logits))) - predicted-logit.
loss
=
torch
.
log
(
sum_exp_logits
)
-
predicted_logits
loss
=
torch
.
log
(
sum_exp_logits
)
-
predicted_logits
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment