Commit d46aa964
Authored Oct 08, 2021 by Jared Casper

Merge branch 'api_change' into 'main'

API improvements. See merge request ADLR/megatron-lm!337

Parents: b31e1296 0694205c
Changes: 3 changed files with 116 additions and 91 deletions (+116 -91)

  megatron/text_generation_server.py   +45  -20
  megatron/text_generation_utils.py    +68  -68
  tools/text_generation_cli.py          +3   -3
megatron/text_generation_server.py

@@ -39,10 +39,19 @@ class MegatronGenerate(Resource):
         print("request IP: " + str(request.remote_addr))
         print(json.dumps(request.get_json()), flush=True)
         print("current time: ", datetime.datetime.now())
-        sentences = request.get_json()["sentences"]
-        if len(sentences) > 128:
-            return "Maximum number of sentences is 128", 400
+        if not "prompts" in request.get_json():
+            return "prompts argument required", 400
+
+        if "max_len" in request.get_json():
+            return "max_len is no longer used.  Replace with tokens_to_generate", 400
+
+        if "sentences" in request.get_json():
+            return "sentences is no longer used.  Replace with prompts", 400
+
+        prompts = request.get_json()["prompts"]
+        if len(prompts) > 128:
+            return "Maximum number of prompts is 128", 400
 
         tokens_to_generate = 64  # Choosing hopefully sane default.  Full sequence is slow
         if "tokens_to_generate" in request.get_json():

@@ -52,11 +61,11 @@ class MegatronGenerate(Resource):
         if tokens_to_generate < 1:
             return "tokens_to_generate must be an integer greater than 0"
 
-        all_probs = False
-        if "all_probs" in request.get_json():
-            all_probs = request.get_json()["all_probs"]
-            if not isinstance(all_probs, bool):
-                return "all_probs must be a boolean value"
+        logprobs = False
+        if "logprobs" in request.get_json():
+            logprobs = request.get_json()["logprobs"]
+            if not isinstance(logprobs, bool):
+                return "logprobs must be a boolean value"
 
         temperature = args.temperature
         if "temperature" in request.get_json():

@@ -66,6 +75,22 @@ class MegatronGenerate(Resource):
             if not (0.0 < temperature <= 100.0):
                 return "temperature must be a positive number less than or equal to 100.0"
 
+        top_k = args.top_k
+        if "top_k" in request.get_json():
+            top_k = request.get_json()["top_k"]
+            if not (type(top_k) == int):
+                return "top_k must be an integer equal to or greater than 0 and less than or equal to 1000"
+            if not (0 < top_k <= 1000):
+                return "top_k must be equal to or greater than 0 and less than or equal to 1000"
+
+        top_p = args.top_p
+        if "top_p" in request.get_json():
+            top_p = request.get_json()["top_p"]
+            if not (type(top_p) == float):
+                return "top_p must be a positive float less than or equal to 1.0"
+            if not (0 < top_p <= 1.0):
+                return "top_p must be less than or equal to 1.0"
+
         add_BOS = False
         if "add_BOS" in request.get_json():
             add_BOS = request.get_json()["add_BOS"]

@@ -74,24 +99,24 @@ class MegatronGenerate(Resource):
         with lock:  # Need to get lock to keep multiple threads from hitting code
             MegatronGenerate.send_do_generate()  # Tell other ranks we're doing generate
-            resp_sentences, resp_sentences_seg, output_logits, full_logits, tokens = generate(self.model, sentences, tokens_to_generate, all_probs, temperature, add_BOS)
-            if all_probs:
-                return jsonify({"sentences": resp_sentences,
-                    "segments": resp_sentences_seg,
-                    "logits": output_logits,
-                    "all_logits": full_logits,
-                    "tokens": tokens})
-
-            return jsonify({"sentences": resp_sentences,
-                "segments": resp_sentences_seg,
-                "logits": output_logits})
+            response, response_seg, response_logprobs = generate(self.model,
+                                                                 prompts,
+                                                                 tokens_to_generate,
+                                                                 logprobs,
+                                                                 temperature,
+                                                                 top_k,
+                                                                 top_p,
+                                                                 add_BOS)
+
+            return jsonify({"text": response,
+                "segments": response_seg,
+                "logprobs": response_logprobs})
 
 class MegatronServer(object):
     def __init__(self, model):
         self.app = Flask(__name__, static_url_path='')
         api = Api(self.app)
-        api.add_resource(MegatronGenerate, '/generate', resource_class_args=[model])
+        api.add_resource(MegatronGenerate, '/api', resource_class_args=[model])
 
     def run(self, url):
         self.app.run(url, threaded=True, debug=False)
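For reference, a minimal client sketch against the updated endpoint. This is not part of the commit: the URL is a placeholder (the server listens wherever MegatronServer.run() was pointed), and the third-party requests package is an assumption. The field names, response keys, and the PUT verb are taken from the diff above and from the CLI change below.

# Minimal client sketch for the updated API (illustration only, not part of this commit).
import requests

SERVER_URL = "http://localhost:5000/api"   # '/api' replaces the old '/generate' route

payload = {
    "prompts": ["Deep learning is"],       # replaces "sentences"
    "tokens_to_generate": 32,              # replaces "max_len"
    "logprobs": True,                      # replaces "all_probs"
    "temperature": 1.0,
    "top_k": 50,                           # new: integer in (0, 1000]
    "top_p": 0.9,                          # new: float in (0, 1.0]
    "add_BOS": False,
}

# The bundled CLI sends a PUT request, so the same verb is used here.
resp = requests.put(SERVER_URL, json=payload)
resp.raise_for_status()
result = resp.json()
print(result["text"][0])       # generated text (the old response key was "sentences")
print(result["segments"][0])   # per-token segments for the first prompt
print(result["logprobs"][0])   # token log-probabilities (old key: "logits")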
megatron/text_generation_utils.py

@@ -108,12 +108,12 @@ def tokenize_batch(sentences, max_len, add_BOS):
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
     return context_tokens_tensor, context_length_tensor
 
-def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
+def send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
     """
     Needs to be synced up with receive_generate_info
     """
     # Send the sizes of the tensors
-    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, all_probs, temperature]
+    input_info = [context_tokens_tensor.size(0), context_tokens_tensor.size(1), tokens_to_generate, logprobs, temperature, top_k, top_p]
     input_info_tensor = torch.cuda.FloatTensor(input_info)
     torch.distributed.broadcast(input_info_tensor, 0)

@@ -125,13 +125,15 @@ def receive_generate_info():
     """
     Needs to be synced up with send_generate_info
     """
-    input_info_tensor = torch.empty(5, dtype=torch.float32, device=torch.cuda.current_device())
+    input_info_tensor = torch.empty(7, dtype=torch.float32, device=torch.cuda.current_device())
     torch.distributed.broadcast(input_info_tensor, 0)
     batch_size = int(input_info_tensor[0].item())
     seq_len = int(input_info_tensor[1].item())
     tokens_to_generate = int(input_info_tensor[2].item())
-    all_probs = int(input_info_tensor[3].item())
+    logprobs = bool(input_info_tensor[3].item())
     temperature = float(input_info_tensor[4].item())
+    top_k = int(input_info_tensor[5].item())
+    top_p = float(input_info_tensor[6].item())
 
     context_length_tensor = torch.empty(batch_size, dtype=torch.int64, device=torch.cuda.current_device())
     context_tokens_tensor = torch.empty(batch_size, seq_len, dtype=torch.int64, device=torch.cuda.current_device())

@@ -140,56 +142,53 @@ def receive_generate_info():
     torch.distributed.broadcast(context_length_tensor, 0)
     torch.distributed.broadcast(context_tokens_tensor, 0)
 
-    return context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature
+    return context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p
 
-def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature):
+def synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p):
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor)
 
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
                                                  attention_mask, position_ids,
                                                  tokens_to_generate,
-                                                 all_probs,
-                                                 temperature=temperature)
-    for tokens, lengths, output_logits, full_logits in batch_token_iterator:
+                                                 logprobs,
+                                                 temperature,
+                                                 top_k,
+                                                 top_p)
+    for tokens, lengths, output_logits in batch_token_iterator:
         context_length += 1
 
-    if mpu.is_pipeline_last_stage():
-        src = mpu.get_pipeline_model_parallel_last_rank()
-        group = mpu.get_embedding_group()
-        torch.distributed.broadcast(output_logits, src, group)
-        if all_probs:
-            src = mpu.get_pipeline_model_parallel_last_rank()
-            group = mpu.get_embedding_group()
-            torch.distributed.broadcast(full_logits, src, group)
-    else:
-        if mpu.is_pipeline_first_stage():
-            src = mpu.get_pipeline_model_parallel_last_rank()
-            group = mpu.get_embedding_group()
-            output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
-            torch.distributed.broadcast(output_logits, src, group)
-            if all_probs:
-                args = get_args()
-                src = mpu.get_pipeline_model_parallel_last_rank()
-                group = mpu.get_embedding_group()
-                full_logits = torch.empty(tokens.size(0), context_length, args.padded_vocab_size, dtype=torch.float32, device=torch.device("cuda"))
-                torch.distributed.broadcast(full_logits, src, group)
+    if logprobs:
+        if mpu.is_pipeline_last_stage():
+            src = mpu.get_pipeline_model_parallel_last_rank()
+            group = mpu.get_embedding_group()
+            torch.distributed.broadcast(output_logits, src, group)
+        else:
+            if mpu.is_pipeline_first_stage():
+                src = mpu.get_pipeline_model_parallel_last_rank()
+                group = mpu.get_embedding_group()
+                output_logits = torch.empty(tokens.size(0), context_length-1, dtype=torch.float32, device=torch.device("cuda"))
+                torch.distributed.broadcast(output_logits, src, group)
 
     if tokens is not None:
-        return tokens[:, :context_length], output_logits, full_logits
+        return tokens[:, :context_length], output_logits
 
-def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, temperature=1.0, add_BOS=False):
+def generate(model, sentences=None, tokens_to_generate=0, logprobs=False, temperature=1.0, top_k=0, top_p=0.0, add_BOS=False):
     model.eval()
     if torch.distributed.get_rank() == 0:
         context_tokens_tensor, context_length_tensor = tokenize_batch(sentences, tokens_to_generate, add_BOS)
-        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
+        send_generate_info(context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
     else:
-        context_length_tensor, context_tokens_tensor, tokens_to_generate, all_probs, temperature = receive_generate_info()
+        context_length_tensor, context_tokens_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p = receive_generate_info()
 
-    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, all_probs, temperature)
+    output = synced_generate(model, context_tokens_tensor, context_length_tensor, tokens_to_generate, logprobs, temperature, top_k, top_p)
     if output is not None:
-        decode_tokens, output_logits, full_logits = output
+        decode_tokens, output_logits = output
 
     args = get_args()
     tokenizer = get_tokenizer()

@@ -197,7 +196,8 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
         resp_sentences_seg = []
 
         decode_tokens = decode_tokens.cpu().numpy().tolist()
-        for decode_token in decode_tokens:
+        for i, decode_token in enumerate(decode_tokens):
             resp_sentences.append(tokenizer.detokenize(decode_token))
             words = []
             for token in decode_token:

@@ -205,12 +205,10 @@ def generate(model, sentences=None, tokens_to_generate=0, all_probs=False, tempe
                 word = bytearray([tokenizer.tokenizer.byte_decoder[c] for c in word]).decode('utf-8', errors='replace')
                 words.append(word)
             resp_sentences_seg.append(words)
 
-        output_logits = output_logits.cpu().numpy().tolist()
-        if all_probs:
-            full_logits = full_logits.cpu().numpy().tolist()
+        if logprobs:
+            output_logits = output_logits.cpu().numpy().tolist()
 
-        return resp_sentences, resp_sentences_seg, output_logits, full_logits, decode_tokens
+        return resp_sentences, resp_sentences_seg, output_logits
 
 def generate_samples_eval(model, context, max_gen_length, eos_token_id):
     """

@@ -260,9 +258,17 @@ def forward_step(model, tokens, position_ids, attention_mask, tokentype_ids,
     return output_tensor
 
-def sample_sequence_batch(model, context_tokens, context_lengths,
-                          attention_mask, position_ids,
-                          tokens_to_generate, all_probs=False, type_ids=None, temperature=None):
+def sample_sequence_batch(model,
+                          context_tokens,
+                          context_lengths,
+                          attention_mask,
+                          position_ids,
+                          tokens_to_generate,
+                          logprobs,
+                          temperature,
+                          top_k,
+                          top_p,
+                          type_ids=None):
     args = get_args()
     tokenizer = get_tokenizer()

@@ -330,8 +336,8 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 else:
                     logits = logits.float()
                     logits /= temperature
-                    logits = top_k_logits(logits, top_k=args.top_k,
-                                          top_p=args.top_p)
+                    logits = top_k_logits(logits, top_k=top_k,
+                                          top_p=top_p)
                     log_probs = F.softmax(logits, dim=-1)
                     prev = torch.multinomial(log_probs, num_samples=1).view(-1)
 
                 started = context_lengths <= context_length

@@ -343,22 +349,19 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 new_tokens = switch(
                     tokens[:, context_length].view(-1), prev, started)
                 tokens[:, context_length] = new_tokens
-                if output_logits is None:
-                    output_context = F.log_softmax(output[:, :context_length, :], 2)
-                    indices = torch.unsqueeze(tokens[:, 1:context_length+1], 2)
-                    output_logits = torch.gather(output_context, 2, indices).squeeze(2)
-                    if all_probs:
-                        full_logits = output_context
-                else:
-                    output_context = F.log_softmax(output, 2)
-                    indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
-                    new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)
-
-                    # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
-                    output_logits = torch.cat([output_logits, new_output_logits], 1)
-                    if all_probs:
-                        full_logits = torch.cat([full_logits, output_context], 1)
+                if logprobs:
+                    if output_logits is None:
+                        output_context = F.log_softmax(output[:, :context_length, :], 2)
+                        indices = torch.unsqueeze(tokens[:, 1:context_length+1], 2)
+                        output_logits = torch.gather(output_context, 2, indices).squeeze(2)
+                    else:
+                        output_context = F.log_softmax(output, 2)
+                        indices = torch.unsqueeze(new_tokens, 1).unsqueeze(2)
+                        new_output_logits = torch.gather(output_context, 2, indices).squeeze(2)
+
+                        # TODO(rprenger) we're copying output_logits every time.  Should pre-allocate
+                        output_logits = torch.cat([output_logits, new_output_logits], 1)
 
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_embedding_group()

@@ -373,10 +376,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 src = mpu.get_pipeline_model_parallel_last_rank()
                 group = mpu.get_pipeline_model_parallel_group()
                 torch.distributed.broadcast(done, src, group)
-                if all_probs:
-                    yield tokens, lengths, output_logits, full_logits
-                else:
-                    yield tokens, lengths, output_logits, None
+                yield tokens, lengths, output_logits
 
             else:
                 if mpu.is_pipeline_first_stage():

@@ -385,9 +385,9 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     new_tokens = torch.empty_like(tokens[:, context_length])
                     torch.distributed.broadcast(new_tokens, src, group)
                     tokens[:, context_length] = new_tokens
-                    yield tokens, None, None, None
+                    yield tokens, None, None
                 else:
-                    yield None, None, None, None
+                    yield None, None, None
 
             done = torch.cuda.ByteTensor([0])
             src = mpu.get_pipeline_model_parallel_last_rank()
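After this change, send_generate_info and receive_generate_info pack seven scalars (batch size, sequence length, tokens_to_generate, logprobs, temperature, top_k, top_p) instead of five into a single float tensor so one broadcast from rank 0 carries them all. A single-process sketch of that packing and unpacking order follows; the helper names are hypothetical and the torch.distributed.broadcast call itself is omitted.

# Single-process sketch of the seven-value info tensor used by
# send_generate_info/receive_generate_info after this change.
import torch

def pack_generate_info(batch_size, seq_len, tokens_to_generate,
                       logprobs, temperature, top_k, top_p):
    # Order matters: the receiving side reads the slots back by index.
    return torch.tensor([batch_size, seq_len, tokens_to_generate,
                         logprobs, temperature, top_k, top_p],
                        dtype=torch.float32)

def unpack_generate_info(info):
    return (int(info[0].item()),    # batch_size
            int(info[1].item()),    # seq_len
            int(info[2].item()),    # tokens_to_generate
            bool(info[3].item()),   # logprobs
            float(info[4].item()),  # temperature
            int(info[5].item()),    # top_k
            float(info[6].item()))  # top_p

info = pack_generate_info(2, 16, 64, True, 1.0, 50, 0.9)
print(unpack_generate_info(info))  # round-tripped values (top_p subject to float32 rounding)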
tools/text_generation_cli.py

@@ -25,10 +25,10 @@ if __name__ == "__main__":
     url = sys.argv[1]
     while True:
         sentence = raw_input("Enter prompt: ")
-        max_len = int(input("Enter number tokens output: "))
-        data = json.dumps({"sentences": [sentence], "max_len": max_len})
+        tokens_to_generate = int(input("Enter number of tokens to generate: "))
+        data = json.dumps({"prompts": [sentence], "tokens_to_generate": tokens_to_generate})
         req = PutRequest(url, data, {'Content-Type': 'application/json'})
         response = urllib2.urlopen(req)
         resp_sentences = json.load(response)
         print("Megatron Response: ")
-        print(resp_sentences["sentences"][0])
+        print(resp_sentences["text"][0])
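The committed CLI keeps its Python 2 idioms (raw_input, urllib2, PutRequest). A rough Python 3 equivalent of the updated loop, for reference only and not part of this commit:

# Rough Python 3 sketch of the updated CLI loop (illustration only).
import json
import sys
import urllib.request

if __name__ == "__main__":
    url = sys.argv[1]
    while True:
        sentence = input("Enter prompt: ")
        tokens_to_generate = int(input("Enter number of tokens to generate: "))
        data = json.dumps({"prompts": [sentence],
                           "tokens_to_generate": tokens_to_generate}).encode("utf-8")
        req = urllib.request.Request(url, data=data,
                                     headers={'Content-Type': 'application/json'},
                                     method='PUT')
        with urllib.request.urlopen(req) as response:
            resp_sentences = json.loads(response.read().decode("utf-8"))
        print("Megatron Response: ")
        print(resp_sentences["text"][0])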