OpenDAS / Megatron-LM / Commits

Commit cb57c380
Authored Sep 20, 2021 by rprenger

    Fixing merge conflicts

Parents: 7bdeb1e7, 87023abd

Showing 7 changed files with 170 additions and 144 deletions (+170 -144)
megatron/model/gpt_model.py            +6   -15
megatron/model/language_model.py       +16  -12
megatron/model/transformer.py          +87  -70
megatron/text_generation_server.py     +8   -1
megatron/text_generation_utils.py      +34  -29
megatron/training.py                   +18  -16
tools/run_text_generation_server.py    +1   -1
megatron/model/gpt_model.py

...
@@ -29,23 +29,15 @@ from .utils import scaled_init_method_normal
 def post_language_model_processing(lm_output, labels, logit_weights,
-                                   get_key_value, parallel_output,
-                                   forward_method_parallel_output,
+                                   parallel_output,
                                    fp16_lm_cross_entropy):
-    if get_key_value:
-        lm_output, presents = lm_output
 
     # Output.
-    if forward_method_parallel_output is not None:
-        parallel_output = forward_method_parallel_output
     output = parallel_lm_logits(
         lm_output,
         logit_weights,
         parallel_output)
 
-    if get_key_value:
-        output = [output, presents]
-
     if labels is None:
         return output
     else:
...
@@ -90,23 +82,22 @@ class GPTModel(MegatronModule):
         self.language_model.set_input_tensor(input_tensor)
 
     def forward(self, input_ids, position_ids, attention_mask, labels=None,
-                tokentype_ids=None, layer_past=None, get_key_value=False,
-                forward_method_parallel_output=None):
+                tokentype_ids=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
 
         lm_output = self.language_model(
             input_ids,
             position_ids,
             attention_mask,
-            layer_past=layer_past,
-            get_key_value=get_key_value)
+            set_inference_key_value_memory=set_inference_key_value_memory,
+            inference_max_sequence_len=inference_max_sequence_len)
 
         if self.post_process:
             return post_language_model_processing(
                 lm_output, labels,
                 self.word_embeddings_weight(),
-                get_key_value,
                 self.parallel_output,
-                forward_method_parallel_output,
                 self.fp16_lm_cross_entropy)
         else:
             return lm_output
...
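Note (not part of the commit): a minimal sketch of how the reworked GPTModel.forward is meant to be driven during incremental decoding, mirroring the call pattern of sample_sequence_batch in megatron/text_generation_utils.py below. Tensor/pipeline-parallel output handling and batching edge cases are ignored; the helper name and greedy decoding are illustrative only.

import torch

def greedy_decode_sketch(model, tokens, position_ids, attention_mask,
                         context_length, max_len):
    # `model` is assumed to behave like GPTModel with labels=None, returning
    # logits of shape [batch, seq, vocab] on the last pipeline stage.
    model.eval()
    counter = 0
    with torch.no_grad():
        while context_length < max_len:
            if counter == 0:
                # First step: run the whole prompt and let every attention
                # layer preallocate key/value memory for max_len positions.
                tokens2use = tokens[:, :context_length]
                positions2use = position_ids[:, :context_length]
                set_kv_memory = True
            else:
                # Later steps: feed only the newest token; its key/value is
                # appended into the buffers allocated at step 0.
                tokens2use = tokens[:, context_length - 1].view(tokens.size(0), -1)
                positions2use = position_ids[:, context_length - 1].view(tokens.size(0), -1)
                set_kv_memory = False
            logits = model(tokens2use, positions2use, attention_mask,
                           set_inference_key_value_memory=set_kv_memory,
                           inference_max_sequence_len=max_len)
            tokens[:, context_length] = torch.argmax(logits[:, -1], dim=-1)
            context_length += 1
            counter += 1
    return tokens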
megatron/model/language_model.py

...
@@ -334,8 +334,10 @@ class TransformerLanguageModel(MegatronModule):
     def forward(self, enc_input_ids, enc_position_ids, enc_attn_mask,
                 dec_input_ids=None, dec_position_ids=None, dec_attn_mask=None,
-                enc_dec_attn_mask=None, tokentype_ids=None, layer_past=None,
-                get_key_value=False, pooling_sequence_index=0,
+                enc_dec_attn_mask=None, tokentype_ids=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None,
+                pooling_sequence_index=0,
                 enc_hidden_states=None, output_enc_hidden=False):
 
         # Embeddings.
...
@@ -348,10 +350,11 @@ class TransformerLanguageModel(MegatronModule):
         # encoder.
         if enc_hidden_states is None:
-            encoder_output = self.encoder(encoder_input,
-                                          enc_attn_mask,
-                                          layer_past=layer_past,
-                                          get_key_value=get_key_value)
+            encoder_output = self.encoder(
+                encoder_input,
+                enc_attn_mask,
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=inference_max_sequence_len)
         else:
             encoder_output = enc_hidden_states.to(encoder_input.dtype)
...
@@ -373,12 +376,13 @@ class TransformerLanguageModel(MegatronModule):
         dec_embedding_output = self.embedding(dec_input_ids,
                                               dec_position_ids)
 
         # decoder
-        decoder_output = self.decoder(dec_embedding_output,
-                                      dec_attn_mask,
-                                      layer_past=layer_past,
-                                      get_key_value=get_key_value,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask)
+        decoder_output = self.decoder(
+            dec_embedding_output,
+            dec_attn_mask,
+            encoder_output=encoder_output,
+            enc_dec_attn_mask=enc_dec_attn_mask,
+            set_inference_key_value_memory=set_inference_key_value_memory,
+            inference_max_sequence_len=inference_max_sequence_len)
 
         if self.add_pooler and self.post_process:
             return decoder_output, encoder_output, pooled_output
...
megatron/model/transformer.py

...
@@ -118,6 +118,7 @@ class ParallelAttention(MegatronModule):
         self.layer_number = max(1, layer_number)
         self.attention_type = attention_type
         self.attn_mask_type = attn_mask_type
+        self.params_dtype = args.params_dtype
 
         projection_size = args.kv_channels * args.num_attention_heads
...
@@ -178,10 +179,53 @@ class ParallelAttention(MegatronModule):
             init_method=output_layer_init_method,
             skip_bias_add=True)
 
-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None):
+        # Inference key-value memory
+        self.inference_key_memory = None
+        self.inference_value_memory = None
+        self.inference_current_sequence_len = 0
+
+    def _allocate_memory(self, inference_max_sequence_len, batch_size):
+        return torch.empty(
+            inference_max_sequence_len,
+            batch_size,
+            self.num_attention_heads_per_partition,
+            self.hidden_size_per_attention_head,
+            dtype=self.params_dtype,
+            device=torch.cuda.current_device())
+
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
         # hidden_states: [sq, b, h]
 
+        # =================================================
+        # Pre-allocate memory for key-values for inference.
+        # =================================================
+        if set_inference_key_value_memory:
+            assert inference_max_sequence_len and inference_max_sequence_len > 0
+            self.inference_key_memory = self._allocate_memory(
+                inference_max_sequence_len, hidden_states.size(1))
+            self.inference_value_memory = self._allocate_memory(
+                inference_max_sequence_len, hidden_states.size(1))
+            self.inference_current_sequence_len = 0
+        # Some consistency check.
+        if inference_max_sequence_len:
+            assert self.inference_current_sequence_len < \
+                self.inference_key_memory.size(0)
+            assert inference_max_sequence_len == \
+                self.inference_key_memory.size(0)
+        # This is added for safety. In case inference_max_sequence_len
+        # is not provided, make sure there is no potential memory left
+        # from previous inference.
+        if not inference_max_sequence_len:
+            self.inference_key_memory = None
+            self.inference_value_memory = None
+
         # =====================
         # Query, Key, and Value
         # =====================
...
@@ -222,18 +266,24 @@ class ParallelAttention(MegatronModule):
             self.hidden_size_per_attention_head)
         query_layer = query_layer.view(*new_tensor_shape)
 
-        # ==================================
-        # Adjust key and value for inference
-        # ==================================
-
-        if layer_past is not None:
-            past_key, past_value = layer_past
-            key_layer = torch.cat((past_key.type_as(key_layer),
-                                   key_layer), dim=0)
-            value_layer = torch.cat((past_value.type_as(value_layer),
-                                     value_layer), dim=0)
-        if get_key_value:
-            present = (key_layer, value_layer)
+        # ===================================================
+        # Adjust key, value, and attention mask for inference
+        # ===================================================
+
+        if inference_max_sequence_len:
+            # Adjust the range variables.
+            start = self.inference_current_sequence_len
+            self.inference_current_sequence_len += key_layer.size(0)
+            end = self.inference_current_sequence_len
+            # Copy key and values.
+            self.inference_key_memory[start:end, ...] = key_layer
+            self.inference_value_memory[start:end, ...] = value_layer
+            key_layer = self.inference_key_memory[:end, ...]
+            value_layer = self.inference_value_memory[:end, ...]
+            # Adjust attention mask
+            attention_mask = attention_mask[..., start:end, :end]
 
         # ===================================
         # Raw attention scores. [b, np, s, s]
...
@@ -270,22 +320,6 @@ class ParallelAttention(MegatronModule):
         # change view to [b, np, sq, sk]
         attention_scores = matmul_result.view(*output_size)
 
-        # ==================================================
-        # Update attention mask for inference. [b, np, sq, sk]
-        # ==================================================
-
-        if get_key_value:
-            with torch.no_grad():
-                if layer_past is not None:
-                    attention_mask = attention_mask[
-                        ...,
-                        attention_scores.size(3) - 1,
-                        :attention_scores.size(3)].unsqueeze(2)
-                else:
-                    attention_mask = attention_mask[
-                        ...,
-                        :attention_scores.size(3),
-                        :attention_scores.size(3)]
-
         # ===========================
         # Attention probs and dropout
...
@@ -341,9 +375,6 @@ class ParallelAttention(MegatronModule):
         output, bias = self.dense(context_layer)
 
-        if get_key_value:
-            output = [output, present]
-
         return output, bias
...
@@ -430,21 +461,21 @@ class ParallelTransformerLayer(MegatronModule):
             output_layer_init_method)
 
     def forward(self, hidden_states, attention_mask,
-                encoder_output=None, enc_dec_attn_mask=None,
-                layer_past=None, get_key_value=False):
+                encoder_output=None,
+                enc_dec_attn_mask=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
         # hidden_states: [b, s, h]
 
         # Layer norm at the beginning of the transformer layer.
         layernorm_output = self.input_layernorm(hidden_states)
         # Self attention.
         attention_output, attention_bias = \
-            self.self_attention(layernorm_output,
-                                attention_mask,
-                                layer_past=layer_past,
-                                get_key_value=get_key_value)
-
-        if get_key_value:
-            attention_output, presents = attention_output
+            self.self_attention(
+                layernorm_output,
+                attention_mask,
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=inference_max_sequence_len)
 
         # Residual connection.
         if self.apply_residual_connection_post_layernorm:
...
@@ -514,9 +545,6 @@ class ParallelTransformerLayer(MegatronModule):
                 residual,
                 self.hidden_dropout)
 
-        if get_key_value:
-            output = [output, presents]
-
         return output
...
@@ -659,18 +687,16 @@ class ParallelTransformer(MegatronModule):
        forward_step_func"""
         self.input_tensor = input_tensor
 
-    def forward(self, hidden_states, attention_mask, layer_past=None,
-                get_key_value=False, encoder_output=None,
-                enc_dec_attn_mask=None):
+    def forward(self, hidden_states, attention_mask,
+                encoder_output=None, enc_dec_attn_mask=None,
+                set_inference_key_value_memory=False,
+                inference_max_sequence_len=None):
 
         # Checks.
-        if layer_past is not None:
-            assert get_key_value, \
-                'for not None values in layer_past, ' \
-                'expected get_key_value to be set'
-        if get_key_value:
+        if inference_max_sequence_len:
             assert self.activations_checkpoint_method is None, \
-                'get_key_value does not work with ' \
-                'activation checkpointing'
+                'inference does not work with activation checkpointing'
 
         if self.pre_process:
             # Data format change to avoid explicit tranposes : [b s h] --> [s b h].
...
@@ -693,22 +719,15 @@ class ParallelTransformer(MegatronModule):
                                                   encoder_output,
                                                   enc_dec_attn_mask)
         else:
-            if get_key_value:
-                presents = []
             for index in range(self.num_layers):
                 layer = self._get_layer(index)
-                past = None
-                if layer_past is not None:
-                    past = layer_past[index]
-                hidden_states = layer(hidden_states,
-                                      attention_mask,
-                                      encoder_output=encoder_output,
-                                      enc_dec_attn_mask=enc_dec_attn_mask,
-                                      layer_past=past,
-                                      get_key_value=get_key_value)
-                if get_key_value:
-                    hidden_states, present = hidden_states
-                    presents.append(present)
+                hidden_states = layer(
+                    hidden_states,
+                    attention_mask,
+                    encoder_output=encoder_output,
+                    enc_dec_attn_mask=enc_dec_attn_mask,
+                    set_inference_key_value_memory=set_inference_key_value_memory,
+                    inference_max_sequence_len=inference_max_sequence_len)
 
         # Final layer norm.
         if self.post_process:
...
@@ -717,7 +736,5 @@ class ParallelTransformer(MegatronModule):
             output = self.final_layernorm(hidden_states)
         else:
             output = hidden_states
-        if get_key_value:
-            output = [output, presents]
 
         return output
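Note (not part of the commit): a stripped-down sketch of the preallocated key/value caching scheme ParallelAttention adopts here, shown outside the Megatron class for clarity; the class name is invented, and dtype/device defaults are simplified (the real code allocates on the current CUDA device in params_dtype).

import torch

class InferenceKVCacheSketch:
    """Buffers of shape [max_seq_len, batch, heads, head_dim] allocated once,
    then filled slice by slice, instead of torch.cat over layer_past."""

    def __init__(self, max_seq_len, batch_size, num_heads, head_dim,
                 dtype=torch.float32, device='cpu'):
        shape = (max_seq_len, batch_size, num_heads, head_dim)
        self.key_memory = torch.empty(shape, dtype=dtype, device=device)
        self.value_memory = torch.empty(shape, dtype=dtype, device=device)
        self.current_len = 0

    def update(self, key_layer, value_layer):
        # Copy the new keys/values into the preallocated buffers and return
        # views over everything cached so far; no reallocation, no concat.
        start = self.current_len
        self.current_len += key_layer.size(0)
        end = self.current_len
        self.key_memory[start:end, ...] = key_layer
        self.value_memory[start:end, ...] = value_layer
        return self.key_memory[:end, ...], self.value_memory[:end, ...]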
megatron/text_generation_server.py

...
@@ -58,6 +58,13 @@ class MegatronGenerate(Resource):
         if not isinstance(all_probs, bool):
             return "all_probs must be a boolean value"
 
+        temperature = args.temperature
+        if "temperature" in request.get_json():
+            temperature = request.get_json()["temperature"]
+            if not isinstance(temperature, float) or not \
+                0.0 < temperature <= 100.0:
+                return "temperature must be a positive float less than or equal to 100.0"
+
         add_BOS = False
         if "add_BOS" in request.get_json():
             add_BOS = request.get_json()["add_BOS"]
...
@@ -66,7 +73,7 @@ class MegatronGenerate(Resource):
         sem.acquire()  # Need to get lock to keep multiple threads from hitting code
         MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, add_BOS)
+        resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature, add_BOS)
         sem.release()
 
         if all_probs:
...
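Note (not part of the commit): a hypothetical client call exercising the new temperature field. The endpoint URL, port, path, HTTP verb, and the "sentences"/"tokens_to_generate" key names are assumptions about the surrounding API; only the "temperature", "add_BOS", and "all_probs" handling is visible in this diff.

import requests

response = requests.put(
    "http://localhost:5000/generate",     # assumed host, port, and path
    json={
        "sentences": ["Megatron-LM is"],  # assumed key name
        "tokens_to_generate": 32,         # assumed key name
        "all_probs": False,
        "add_BOS": False,
        "temperature": 0.9,               # new field: float in (0.0, 100.0]
    },
)
print(response.json())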
megatron/text_generation_utils.py

...
@@ -141,14 +141,15 @@ def receive_generate_info():
     return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs
 
-def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
 
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
                                                  attention_mask, position_ids,
                                                  tokens_to_generate,
-                                                 all_probs)
+                                                 all_probs,
+                                                 temperature=temperature)
 
     for tokens, lengths, output_logits, full_logits in batch_token_iterator:
         context_length += 1
...
@@ -177,16 +178,15 @@ def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_
     if tokens is not None:
         return tokens[:, :context_length], output_logits, full_logits
 
-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, add_BOS=False):
+def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
         send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
     else:
         context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs = receive_generate_info()
 
-    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
     if output is not None:
         decode_tokens, output_logits, full_logits = output
...
@@ -230,8 +230,8 @@ def switch(val1, val2, boolean):
 def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
-                 layer_past=None, get_key_value=None,
-                 forward_method_parallel_output=None):
+                 set_inference_key_value_memory=False,
+                 inference_max_sequence_len=None):
 
     # Hidden size changes when not using recompute, need to tell p2p_communicate
     # functions the correct size
...
@@ -246,26 +246,22 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
         unwrapped_model = unwrap_model(
             model, (torchDDP, LocalDDP, Float16Module))
         unwrapped_model.set_input_tensor(input_tensor)
 
-    output_tensor = model(tokens, position_ids, attention_mask,
-                          tokentype_ids=tokentype_ids,
-                          layer_past=layer_past,
-                          get_key_value=get_key_value,
-                          forward_method_parallel_output=forward_method_parallel_output)
-
-    if get_key_value:
-        output_tensor, layer_past = output_tensor
+    output_tensor = model(
+        tokens, position_ids, attention_mask,
+        tokentype_ids=tokentype_ids,
+        set_inference_key_value_memory=set_inference_key_value_memory,
+        inference_max_sequence_len=inference_max_sequence_len)
 
     send_forward(output_tensor)
 
     args.seq_length = orig_seq_length
-    if get_key_value:
-        return output_tensor, layer_past
     return output_tensor
 
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
-                          tokens_to_generate, all_probs=False, type_ids=None):
+                          tokens_to_generate, all_probs=False, type_ids=None,
+                          temperature=None):
     args = get_args()
     tokenizer = get_tokenizer()
...
@@ -282,7 +278,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
     counter = 0
-    layer_past = None
     batch_size = context_tokens.size(0)
     is_done = torch.zeros([batch_size]).byte().cuda()
     tokens = context_tokens
...
@@ -299,11 +294,15 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         while context_length < maxlen:
             types2use = None
             if counter == 0:
+                # Allocate memory for the entire context.
+                set_inference_key_value_memory = True
                 tokens2use = tokens[:, :context_length]
                 positions2use = position_ids[:, :context_length]
                 if type_ids is not None:
                     types2use = type_ids[:, :context_length]
             else:
+                # Set this to false so the memory is not reallocated.
+                set_inference_key_value_memory = False
                 tokens2use = tokens[:, context_length - 1].view(
                     batch_size, -1)
                 positions2use = position_ids[:, context_length - 1].view(
...
@@ -311,29 +310,35 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 if type_ids is not None:
                     types2use = type_ids[:, context_length - 1].view(
                         batch_size, -1)
 
-            output, layer_past = forward_step(model, tokens2use, positions2use,
-                                              attention_mask,
-                                              layer_past=layer_past,
-                                              get_key_value=True,
-                                              tokentype_ids=types2use,
-                                              forward_method_parallel_output=False)
+            output = forward_step(
+                model, tokens2use,
+                positions2use,
+                attention_mask,
+                set_inference_key_value_memory=set_inference_key_value_memory,
+                inference_max_sequence_len=maxlen,
+                tokentype_ids=types2use)
 
             if mpu.is_pipeline_last_stage():
                 assert output is not None
+                output = output.float()
                 logits = output[:, -1].view(batch_size, -1).contiguous()
 
+            if mpu.is_pipeline_last_stage():
                 if args.greedy:
                     prev = torch.argmax(logits, dim=-1).view(-1)
                 else:
                     logits = logits.float()
-                    logits /= args.temperature
+                    logits /= temperature
                     logits = top_k_logits(logits, top_k=args.top_k,
                                           top_p=args.top_p)
                     log_probs = F.softmax(logits, dim=-1)
                     prev = torch.multinomial(log_probs, num_samples=1).view(-1)
                 started = context_lengths <= context_length
 
+                # Clamp the out of vocabulary tokens.
+                tokenizer = get_tokenizer()
+                prev = torch.clamp(prev, max=tokenizer.vocab_size - 1)
+
                 new_tokens = switch(
                     tokens[:, context_length].view(-1), prev, started)
                 tokens[:, context_length] = new_tokens
...
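Note (not part of the commit): a standalone sketch of the per-step sampling that now uses the request-level temperature instead of args.temperature. top_k_logits is replaced here by an inline top-k filter and top-p filtering is omitted for brevity; the function name is illustrative only.

import torch
import torch.nn.functional as F

def sample_next_token_sketch(logits, temperature=1.0, top_k=0):
    # logits: [batch, vocab] for the last position of each sequence.
    logits = logits.float() / temperature
    if top_k > 0:
        # Mask everything below the k-th largest logit per row.
        kth = torch.topk(logits, top_k, dim=-1).values[..., -1, None]
        logits = logits.masked_fill(logits < kth, float('-inf'))
    probs = F.softmax(logits, dim=-1)
    # Draw one token per sequence from the filtered distribution.
    return torch.multinomial(probs, num_samples=1).view(-1)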
megatron/training.py

...
@@ -189,7 +189,7 @@ def update_train_iters(args):
     print_rank_0('setting training iterations to {}'.format(args.train_iters))
 
 
-def get_model(model_provider_func):
+def get_model(model_provider_func, wrap_with_ddp=True):
     """Build the model."""
     args = get_args()
...
@@ -243,22 +243,24 @@ def get_model(model_provider_func):
     if args.fp16 or args.bf16:
         model = [Float16Module(model_module, args) for model_module in model]
 
-    if args.DDP_impl == 'torch':
-        i = torch.cuda.current_device()
-        model = [torchDDP(model_module, device_ids=[i], output_device=i,
-                          process_group=mpu.get_data_parallel_group())
-                 for model_module in model]
-        return model
-
-    if args.DDP_impl == 'local':
-        model = [LocalDDP(model_module,
-                          args.accumulate_allreduce_grads_in_fp32,
-                          args.use_contiguous_buffers_in_local_ddp)
-                 for model_module in model]
-        return model
-
-    raise NotImplementedError('Unknown DDP implementation specified: {}. '
-                              'Exiting.'.format(args.DDP_impl))
+    if wrap_with_ddp:
+        if args.DDP_impl == 'torch':
+            i = torch.cuda.current_device()
+            model = [torchDDP(model_module, device_ids=[i], output_device=i,
+                              process_group=mpu.get_data_parallel_group())
+                     for model_module in model]
+
+        elif args.DDP_impl == 'local':
+            model = [LocalDDP(model_module,
+                              args.accumulate_allreduce_grads_in_fp32,
+                              args.use_contiguous_buffers_in_local_ddp)
+                     for model_module in model]
+
+        else:
+            raise NotImplementedError('Unknown DDP implementation specified: '
+                                      '{}. Exiting.'.format(args.DDP_impl))
+
+    return model
 
 
 def get_learning_rate_scheduler(optimizer):
...
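Note (not part of the commit): a sketch of how an inference-only script is expected to use the new wrap_with_ddp flag, as tools/run_text_generation_server.py does below; the serve helper is illustrative, and the import paths mirror the existing Megatron modules.

from megatron import get_args
from megatron.checkpointing import load_checkpoint
from megatron.training import get_model

def serve(model_provider):
    args = get_args()
    # Skip the DDP wrapper: a text-generation server only runs forward
    # passes, so gradient all-reduce machinery is never needed.
    model = get_model(model_provider, wrap_with_ddp=False)
    if args.load is not None:
        _ = load_checkpoint(model, None, None)
    return model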
tools/run_text_generation_server.py

...
@@ -65,7 +65,7 @@ if __name__ == "__main__":
         print("Interleaved pipeline schedule is not yet supported for text generation.")
         exit()
     # Set up model and load checkpoint
-    model = get_model(model_provider)
+    model = get_model(model_provider, wrap_with_ddp=False)
     if args.load is not None:
         _ = load_checkpoint(model, None, None)
...