OpenDAS / Megatron-LM · Commits

Commit 5d647381, authored May 18, 2022 by Vijay Korthikanti

    t5 regression fixes

Parent: a7a12f82
Showing 5 changed files with 32 additions and 31 deletions.
megatron/model/bert_model.py     +4   -1
megatron/model/gpt_model.py      +3   -0
megatron/model/t5_model.py       +4   -2
megatron/model/transformer.py    +8   -14
megatron/mpu/layers.py           +13  -14
megatron/model/bert_model.py

@@ -114,13 +114,16 @@ def post_language_model_processing(lm_output, pooled_output,
         return lm_logits.transpose(0,1).contiguous(), binary_logits
     else:
         # [b s] => [s b]
-        lm_logits = lm_logits.transpose(0,1).contiguous()
+        lm_labels = lm_labels.transpose(0,1).contiguous()
+        # lm_logits : [s, b, h] and lm_labels: [s, b]
         if fp16_lm_cross_entropy:
             assert lm_logits.dtype == torch.half
             lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
         else:
             lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                        lm_labels)
+        # [s, b] => [b s]
+        lm_loss = lm_loss.transpose(0,1).contiguous()
         return lm_loss, binary_logits
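Note: the fix keeps lm_logits in the model's sequence-first [s, b, h] layout, runs the vocab-parallel cross entropy there, and only transposes the resulting per-token loss back to [b, s]. A minimal shape sketch of that flow, using plain F.cross_entropy as a stand-in for mpu.vocab_parallel_cross_entropy (an assumption for illustration only):

import torch
import torch.nn.functional as F

s, b, v = 8, 2, 32                        # sequence length, micro-batch, vocab size
lm_logits = torch.randn(s, b, v)          # model output: [s, b, v]
lm_labels = torch.randint(0, v, (b, s))   # data loader output: [b, s]

# [b s] => [s b], so labels line up with the logits layout
lm_labels = lm_labels.transpose(0, 1).contiguous()

# per-token loss computed in [s, b]; the real code calls
# mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels) here
lm_loss = F.cross_entropy(lm_logits.view(s * b, v).float(),
                          lm_labels.reshape(s * b),
                          reduction='none').view(s, b)

# [s, b] => [b s], the layout the rest of the loss pipeline expects
lm_loss = lm_loss.transpose(0, 1).contiguous()
print(lm_loss.shape)   # torch.Size([2, 8])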
megatron/model/gpt_model.py

@@ -49,6 +49,9 @@ def post_language_model_processing(lm_output, labels, logit_weights,
             loss = mpu.vocab_parallel_cross_entropy(output, labels)
         else:
             loss = mpu.vocab_parallel_cross_entropy(output.float(), labels)
+
+        # [s b] => [b, s]
+        loss = loss.transpose(0,1).contiguous()
         return loss
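Note: the loss comes back per-token in [b, s] so the caller can apply its loss mask before averaging. A hedged sketch of that downstream step, modeled on a typical Megatron-LM loss_func (the mask values here are made up for illustration):

import torch

b, s = 2, 8
loss = torch.rand(b, s)          # per-token loss in [b, s], as returned above
loss_mask = torch.ones(b, s)     # 0 where tokens should not contribute (e.g. padding)
loss_mask[:, -2:] = 0.0          # pretend the last two positions are masked

loss_mask = loss_mask.view(-1).float()
averaged_loss = torch.sum(loss.view(-1) * loss_mask) / loss_mask.sum()
print(averaged_loss)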
megatron/model/t5_model.py

@@ -152,7 +152,7 @@ class T5Model(MegatronModule):
         if self.post_process and self.add_decoder:
             decoder_output, encoder_output = lm_output
-            # Output.
+            # Output. [s, b, h]
             lm_logits = self.lm_head(decoder_output,
                                      self.word_embeddings_weight())

@@ -161,13 +161,15 @@ class T5Model(MegatronModule):
                 return lm_logits.transpose(0,1).contiguous()
             else:
                 # [b s] => [s b]
-                lm_labels = lm_lables.transpose(0,1).contiguous()
+                lm_labels = lm_labels.transpose(0,1).contiguous()
                 if self.fp16_lm_cross_entropy:
                     assert lm_logits.dtype == torch.half
                     lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits, lm_labels)
                 else:
                     lm_loss = mpu.vocab_parallel_cross_entropy(lm_logits.float(),
                                                                lm_labels)
+                # [s b] => [b s]
+                lm_loss = lm_loss.transpose(0,1).contiguous()
                 return lm_loss
         elif self.add_decoder and not self.add_encoder:
             decoder_output, encoder_output = lm_output
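Note: the new "[s, b, h]" comment records the layout the decoder hands to lm_head. A small shape sketch of that projection, with an ordinary nn.Linear standing in for the vocab-parallel LM head (an assumption, not the real class):

import torch
import torch.nn as nn

s, b, h, v = 8, 2, 16, 32
decoder_output = torch.randn(s, b, h)   # [s, b, h] from the decoder stack
lm_head = nn.Linear(h, v, bias=False)   # stand-in for the vocab-parallel LM head

lm_logits = lm_head(decoder_output)     # [s, b, v]
print(lm_logits.shape)                                   # torch.Size([8, 2, 32])
# inference path: [s b h] => [b s h] before returning to the caller
print(lm_logits.transpose(0, 1).contiguous().shape)      # torch.Size([2, 8, 32])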
megatron/model/transformer.py

@@ -167,7 +167,6 @@ class SwitchMLP(MegatronModule):
 class CoreAttention(MegatronModule):
 
-    matmul_input_buffer = None
 
     def __init__(self, layer_number,
                  attn_mask_type=AttnMaskType.padding):
@@ -235,21 +234,16 @@ class CoreAttention(MegatronModule):
                                    output_size[0] * output_size[1], -1)
 
         # preallocting input tensor: [b * np, sq, sk]
-        if CoreAttention.matmul_input_buffer is None:
-            CoreAttention.matmul_input_buffer = torch.empty(
-                output_size[0]*output_size[1],
-                output_size[2],
-                output_size[3],
-                dtype=query_layer.dtype,
-                device=torch.cuda.current_device())
-        else:
-            assert CoreAttention.matmul_input_buffer.size() == \
-                (output_size[0]*output_size[1], output_size[2], output_size[3]), \
-                "buffer dimensions should remain the same during the training run"
+        matmul_input_buffer = torch.empty(
+            output_size[0]*output_size[1],
+            output_size[2],
+            output_size[3],
+            dtype=query_layer.dtype,
+            device=torch.cuda.current_device())
 
         # Raw attention scores. [b * np, sq, sk]
         matmul_result = torch.baddbmm(
-            CoreAttention.matmul_input_buffer,
+            matmul_input_buffer,
             query_layer.transpose(0, 1),   # [b * np, sq, hn]
             key_layer.transpose(0, 1).transpose(1, 2),  # [b * np, hn, sk]
             beta=0.0, alpha=(1.0/self.norm_factor))
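Note: with the class-level cache gone, the buffer is just per-call scratch space that torch.baddbmm writes the raw attention scores into; beta=0.0 means its initial contents never matter. A self-contained sketch with made-up sizes:

import torch

b, np_, sq, sk, hn = 2, 4, 8, 8, 16           # batch, heads, query len, key len, head dim
norm_factor = hn ** 0.5

query_layer = torch.randn(sq, b * np_, hn)    # [sq, b * np, hn]
key_layer = torch.randn(sk, b * np_, hn)      # [sk, b * np, hn]

# per-call scratch tensor: [b * np, sq, sk]
matmul_input_buffer = torch.empty(b * np_, sq, sk, dtype=query_layer.dtype)

# raw attention scores; with beta=0.0 the buffer only supplies shape/dtype/device
matmul_result = torch.baddbmm(
    matmul_input_buffer,
    query_layer.transpose(0, 1),                  # [b * np, sq, hn]
    key_layer.transpose(0, 1).transpose(1, 2),    # [b * np, hn, sk]
    beta=0.0, alpha=(1.0 / norm_factor))

print(matmul_result.shape)   # torch.Size([8, 8, 8])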
@@ -921,7 +915,7 @@ class ParallelTransformer(MegatronModule):
         if self.sequence_parallel:
             rng_context = mpu.get_cuda_rng_tracker().fork()
         else:
-            rng_context = nullcontext
+            rng_context = nullcontext()
 
         with rng_context:
             # Forward pass.
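Note: the one-character fix above matters because contextlib.nullcontext has to be instantiated to act as a context manager; handing the class itself to a with statement fails. A quick illustration:

from contextlib import nullcontext

# buggy form: the with statement cannot enter the class itself
try:
    with nullcontext:
        pass
except (AttributeError, TypeError) as exc:
    print("using the class directly fails:", exc)

# fixed form: an instance is a valid no-op context manager
with nullcontext():
    print("instance form works as a no-op context")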
megatron/mpu/layers.py

@@ -205,7 +205,6 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
     Linear layer execution with asynchronous communication and gradient accumulation
     fusion in backprop.
     """
-    all_gather_buffer = None
 
     @staticmethod
     def forward(ctx, input, weight, bias, gradient_accumulation_fusion,

@@ -221,20 +220,15 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size = list(input.size())
             dim_size[0] = dim_size[0] * world_size
 
-            if LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer is None:
-                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer = \
-                    torch.empty(dim_size, dtype=input.dtype,
-                                device=torch.cuda.current_device(),
-                                requires_grad=False)
-            else:
-                assert list(LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer.size()) == dim_size, \
-                    "buffer dimensions should remain same during the training run"
+            all_gather_buffer = \
+                torch.empty(dim_size, dtype=input.dtype,
+                            device=torch.cuda.current_device(),
+                            requires_grad=False)
             torch.distributed._all_gather_base(
-                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
+                all_gather_buffer,
                 input,
                 group=get_tensor_model_parallel_group())
-            total_input = LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer
+            total_input = all_gather_buffer
         else:
             total_input = input

@@ -253,15 +247,20 @@ class LinearWithGradAccumulationAndAsyncCommunication(torch.autograd.Function):
             dim_size = list(input.size())
             dim_size[0] = dim_size[0] * world_size
 
+            all_gather_buffer = \
+                torch.empty(dim_size, dtype=input.dtype,
+                            device=torch.cuda.current_device(),
+                            requires_grad=False)
             handle = torch.distributed._all_gather_base(
-                LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer,
+                all_gather_buffer,
                 input,
                 group=get_tensor_model_parallel_group(), async_op=True)
 
             # Delay the start of intput gradient computation shortly (3us) to have
             # gather scheduled first and have GPU resources allocated
             _ = torch.empty(1, device=grad_output.device) + 1
-            total_input = LinearWithGradAccumulationAndAsyncCommunication.all_gather_buffer
+            total_input = all_gather_buffer
         else:
             total_input = input
 
         grad_input = grad_output.matmul(weight)
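Note: both hunks drop the cached class-level all_gather_buffer in favour of a fresh per-call buffer. A single-process sketch of the shape bookkeeping, where torch.cat merely simulates what torch.distributed._all_gather_base fills in across tensor-model-parallel ranks (sizes are made up):

import torch

world_size = 4                     # stand-in for get_tensor_model_parallel_world_size()
input_ = torch.randn(8, 2, 16)     # local shard: [s / world_size, b, h]

dim_size = list(input_.size())
dim_size[0] = dim_size[0] * world_size            # gathered sequence dim: [s, b, h]
all_gather_buffer = torch.empty(dim_size, dtype=input_.dtype,
                                requires_grad=False)

# In the real code, torch.distributed._all_gather_base(all_gather_buffer, input_,
# group=...) fills the buffer with every rank's shard; simulate the result here.
all_gather_buffer.copy_(torch.cat([input_] * world_size, dim=0))

total_input = all_gather_buffer
print(total_input.shape)           # torch.Size([32, 2, 16])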