Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
Megatron-LM
Commits
b6555b71
"gallery/others/plot_repurposing_annotations.py" did not exist on "11bd2eaa6d6976129836b329b01d1300babddcc9"
Commit
b6555b71
authored
Sep 28, 2021
by
mshoeybi
Browse files
working
parent
6c40f892
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
6 additions
and
19 deletions
+6
-19
megatron/inference/forward_step.py
megatron/inference/forward_step.py
+2
-19
megatron/model/module.py
megatron/model/module.py
+4
-0
No files found.
megatron/inference/forward_step.py
View file @
b6555b71
...
@@ -16,25 +16,10 @@
...
@@ -16,25 +16,10 @@
"""Forward step utilities."""
"""Forward step utilities."""
import
torch
import
torch
from
megatron.p2p_communication
import
recv_forward
,
send_forward
from
megatron.p2p_communication
import
recv_forward
,
send_forward
from
.sampling
import
sample
from
megatron
import
get_args
from
megatron
import
mpu
import
torch.nn.functional
as
F
from
megatron
import
print_rank_0
from
megatron
import
get_args
,
get_tokenizer
from
megatron.utils
import
get_ltor_masks_and_position_ids
,
unwrap_model
from
.communication
import
(
broadcast_float_list
,
copy_from_last_to_first_pipeline_stage
,
broadcast_from_last_pipeline_stage
)
from
.tokenization
import
tokenize_prompts
# These are needed to unwrap the model, would be nice to put these in megatron.utils if possible?
from
torch.nn.parallel.distributed
import
DistributedDataParallel
as
torchDDP
from
megatron.model
import
DistributedDataParallel
as
LocalDDP
from
megatron.model
import
Float16Module
def
forward_step
(
model
,
tokens
,
position_ids
,
attention_mask
,
def
forward_step
(
model
,
tokens
,
position_ids
,
attention_mask
,
...
@@ -51,9 +36,7 @@ def forward_step(model, tokens, position_ids, attention_mask,
...
@@ -51,9 +36,7 @@ def forward_step(model, tokens, position_ids, attention_mask,
input_tensor
=
recv_forward
()
input_tensor
=
recv_forward
()
# Forward pass through the model.
# Forward pass through the model.
unwrapped_model
=
unwrap_model
(
model
.
set_input_tensor
(
input_tensor
)
model
,
(
torchDDP
,
LocalDDP
,
Float16Module
))
unwrapped_model
.
set_input_tensor
(
input_tensor
)
output_tensor
=
model
(
output_tensor
=
model
(
tokens
,
position_ids
,
attention_mask
,
tokens
,
position_ids
,
attention_mask
,
set_inference_key_value_memory
=
set_inference_key_value_memory
,
set_inference_key_value_memory
=
set_inference_key_value_memory
,
...
...
megatron/model/module.py
View file @
b6555b71
...
@@ -166,6 +166,10 @@ class Float16Module(MegatronModule):
...
@@ -166,6 +166,10 @@ class Float16Module(MegatronModule):
self
.
float16_convertor
=
float16_convertor
self
.
float16_convertor
=
float16_convertor
def
set_input_tensor
(
self
,
input_tensor
):
return
self
.
module
.
set_input_tensor
(
input_tensor
)
def
forward
(
self
,
*
inputs
,
**
kwargs
):
def
forward
(
self
,
*
inputs
,
**
kwargs
):
if
mpu
.
is_pipeline_first_stage
():
if
mpu
.
is_pipeline_first_stage
():
inputs
=
fp32_to_float16
(
inputs
,
self
.
float16_convertor
)
inputs
=
fp32_to_float16
(
inputs
,
self
.
float16_convertor
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment