OpenDAS / Megatron-LM · Commit a4ef7c40

update jit warmup code to handle sequence parallelism

Authored May 16, 2022 by Vijay Korthikanti
Parent: 7f9a48ba

Showing 1 changed file with 9 additions and 3 deletions.

megatron/initialize.py (+9, -3)
@@ -266,7 +266,13 @@ def _warmup_jit_function():
     # Warmup fused bias+gelu
     bias = torch.rand(args.ffn_hidden_size // args.tensor_model_parallel_size,
                       dtype=dtype, device='cuda')
-    input = torch.rand((args.seq_length, args.micro_batch_size,
+    if args.sequence_parallel:
+        seq_length = args.seq_length // mpu.get_tensor_model_parallel_world_size()
+    else:
+        seq_length = args.seq_length
+    input = torch.rand((seq_length, args.micro_batch_size,
                         args.ffn_hidden_size // args.tensor_model_parallel_size),
                        dtype=dtype, device='cuda')
     # Warmup JIT fusions with the input grad_enable state of both forward
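For context on the new branch: with sequence parallelism enabled, activations are split along the sequence dimension across the tensor-model-parallel group, so each rank's warmup tensors must use the per-rank slice of args.seq_length. The following is a minimal self-contained sketch of that arithmetic; the helper name per_rank_seq_length is hypothetical and not part of Megatron-LM.

def per_rank_seq_length(seq_length: int,
                        tensor_model_parallel_world_size: int,
                        sequence_parallel: bool) -> int:
    # Mirrors the if/else added above: under sequence parallelism each
    # tensor-parallel rank holds seq_length // world_size rows.
    if sequence_parallel:
        return seq_length // tensor_model_parallel_world_size
    return seq_length

# Example: a global sequence length of 2048 with 8-way tensor parallelism
# and sequence parallelism on gives 2048 // 8 = 256 rows per rank.
assert per_rank_seq_length(2048, 8, True) == 256
assert per_rank_seq_length(2048, 8, False) == 2048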
@@ -278,9 +284,9 @@ def _warmup_jit_function():
     del bias, input, output
     # Warmup fused bias+dropout+add
-    input = torch.rand((args.seq_length, args.micro_batch_size, args.hidden_size),
+    input = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
                        dtype=dtype, device='cuda')
-    residual = torch.rand((args.seq_length, args.micro_batch_size, args.hidden_size),
+    residual = torch.rand((seq_length, args.micro_batch_size, args.hidden_size),
                           dtype=dtype, device='cuda')
     bias = torch.rand((args.hidden_size), dtype=dtype,
                       device='cuda').expand_as(residual)
     dropout_rate = 0.1
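The diff only resizes the warmup tensors; for readers unfamiliar with the surrounding function, below is a rough standalone sketch of the bias+dropout+add warmup with the corrected shapes. It uses plain PyTorch ops in place of Megatron-LM's fused JIT kernel, and the function name, iteration count, and default arguments are illustrative assumptions rather than code from the repository.

import torch
import torch.nn.functional as F

def warmup_bias_dropout_add_sketch(seq_length, micro_batch_size, hidden_size,
                                   dtype=torch.float16, device='cuda',
                                   dropout_rate=0.1, iters=5):
    # Shapes mirror the diff: [per-rank seq_length, micro batch, hidden size].
    input_ = torch.rand((seq_length, micro_batch_size, hidden_size),
                        dtype=dtype, device=device)
    residual = torch.rand((seq_length, micro_batch_size, hidden_size),
                          dtype=dtype, device=device)
    bias = torch.rand(hidden_size, dtype=dtype, device=device).expand_as(residual)
    # Run both the no-grad and grad-enabled paths a few times so any JIT
    # compilation or kernel autotuning happens before the timed training steps.
    for requires_grad in (False, True):
        input_.requires_grad_(requires_grad)
        for _ in range(iters):
            out = F.dropout(input_ + bias, p=dropout_rate, training=True) + residual
    del input_, residual, bias, out

With the per-rank sequence length from the earlier sketch, a call might look like warmup_bias_dropout_add_sketch(256, 4, 4096); if the warmup used the global sequence length instead, the warmup shapes would not match the ones seen during sequence-parallel training, defeating the purpose of pre-compiling the fused kernels.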