OpenDAS / Megatron-LM

Commit 453414da
Authored Jun 30, 2021 by rprenger
Removing unnecessary --recompute path
Parent: f7fe3865
Showing 2 changed files with 22 additions and 36 deletions:

  megatron/text_generation_utils.py   +22  -32
  tools/run_api_server.py             +0   -4
megatron/text_generation_utils.py
@@ -189,40 +189,30 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
         lengths = torch.ones([batch_size]).long().cuda() * maxlen
 
         while context_length <= (maxlen):
-            if args.recompute:
-                output = forward_step(model, tokens,
-                                      position_ids,
-                                      attention_mask,
-                                      tokentype_ids=type_ids,
-                                      forward_method_parallel_output=False)
-                if mpu.is_pipeline_last_stage():
-                    assert output is not None
-                    logits = output[:, context_length - 1, :]
-            else:
-                types2use = None
-                if counter == 0:
-                    tokens2use = tokens[:, :context_length]
-                    positions2use = position_ids[:, :context_length]
-                    if type_ids is not None:
-                        types2use = type_ids[:, :context_length]
-                else:
-                    tokens2use = tokens[:, context_length - 1].view(
-                        batch_size, -1)
-                    positions2use = position_ids[:, context_length - 1].view(
-                        batch_size, -1)
-                    if type_ids is not None:
-                        types2use = type_ids[:, context_length - 1].view(
-                            batch_size, -1)
-                output, layer_past = forward_step(model, tokens2use,
-                                                  positions2use,
-                                                  attention_mask,
-                                                  layer_past=layer_past,
-                                                  get_key_value=True,
-                                                  tokentype_ids=types2use,
-                                                  forward_method_parallel_output=False)
-                if mpu.is_pipeline_last_stage():
-                    assert output is not None
-                    logits = output[:, -1].view(batch_size, -1).contiguous()
+            types2use = None
+            if counter == 0:
+                tokens2use = tokens[:, :context_length]
+                positions2use = position_ids[:, :context_length]
+                if type_ids is not None:
+                    types2use = type_ids[:, :context_length]
+            else:
+                tokens2use = tokens[:, context_length - 1].view(
+                    batch_size, -1)
+                positions2use = position_ids[:, context_length - 1].view(
+                    batch_size, -1)
+                if type_ids is not None:
+                    types2use = type_ids[:, context_length - 1].view(
+                        batch_size, -1)
+            output, layer_past = forward_step(model, tokens2use,
+                                              positions2use,
+                                              attention_mask,
+                                              layer_past=layer_past,
+                                              get_key_value=True,
+                                              tokentype_ids=types2use,
+                                              forward_method_parallel_output=False)
+            if mpu.is_pipeline_last_stage():
+                assert output is not None
+                logits = output[:, -1].view(batch_size, -1).contiguous()
 
             if mpu.is_pipeline_last_stage():
                 if args.greedy:
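Why the flag was unnecessary: the retained branch already implements standard incremental decoding. The first loop iteration runs the whole prompt through forward_step to build the key/value cache (layer_past, via get_key_value=True); every later iteration feeds only the newest token and reuses that cache. The removed --recompute branch produced the same next-token logits by re-running attention over the entire sequence at every step, so dropping it changes cost, not outputs. A minimal sketch of the retained pattern follows; generate_with_cache, toy_model, and the (logits, cache) return convention are illustrative assumptions, not Megatron-LM's API.

import torch

def generate_with_cache(model, tokens, context_length, maxlen):
    # Sketch of the retained decoding loop. `model(tokens2use, layer_past)`
    # is assumed to return (logits, updated_cache), standing in for
    # forward_step(..., layer_past=layer_past, get_key_value=True).
    # `tokens` is assumed pre-allocated to at least maxlen + 1 positions.
    layer_past = None
    counter = 0
    while context_length <= maxlen:
        if counter == 0:
            # First step: process the whole prompt once to fill the cache.
            tokens2use = tokens[:, :context_length]
        else:
            # Later steps: feed only the newest token; keys/values for all
            # earlier positions come from the cache, nothing is recomputed.
            tokens2use = tokens[:, context_length - 1:context_length]
        logits, layer_past = model(tokens2use, layer_past)
        tokens[:, context_length] = logits[:, -1].argmax(dim=-1)  # greedy pick
        context_length += 1
        counter += 1
    return tokens

def toy_model(tokens2use, layer_past):
    # Trivial stand-in model: uniform logits, cache passed through unchanged.
    batch, steps = tokens2use.shape
    return torch.zeros(batch, steps, 8), layer_past

if __name__ == "__main__":
    tokens = torch.zeros(2, 6, dtype=torch.long)  # batch of 2, room for 6 ids
    print(generate_with_cache(toy_model, tokens, context_length=2, maxlen=5))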
tools/run_api_server.py
@@ -55,10 +55,6 @@ def add_text_generate_args(parser):
                        help='Top k sampling.')
     group.add_argument("--out-seq-length", type=int, default=1024,
                        help='Size of the output generated text.')
-    group.add_argument("--recompute", action='store_true',
-                       help='During generation recompute all attention '
-                       'instead of using previously computed keys/values.')
-
     return parser
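One side effect worth noting: scripts that still pass --recompute to run_api_server.py will now fail argument parsing. For reference, a sketch of the surviving options as they appear in this hunk; the argument-group title and the "--top_k" flag spelling are assumptions, since only its help string is visible in the context lines.

import argparse

def add_text_generate_args(parser):
    # Surviving generation options visible in the hunk above. The group
    # title and the "--top_k" spelling are assumptions for this sketch;
    # only the help strings and --out-seq-length appear in the diff.
    group = parser.add_argument_group(title='text generation')
    group.add_argument("--top_k", type=int, default=0,
                       help='Top k sampling.')
    group.add_argument("--out-seq-length", type=int, default=1024,
                       help='Size of the output generated text.')
    return parser

parser = add_text_generate_args(argparse.ArgumentParser())
print(parser.parse_args(["--out-seq-length", "256"]))
# Passing the removed flag now fails with:
#   error: unrecognized arguments: --recompute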