OpenDAS / Megatron-LM · Commits

Commit a6ba254f
Authored Apr 02, 2020 by Mohammad
Parent: a19820b1

    generate samples linted

Showing 2 changed files with 72 additions and 82 deletions:
    generate_samples.py   +72 -82
    pretrain_gpt2.py      +0  -0

generate_samples.py  (file mode changed 100755 → 100644)
@@ -15,34 +15,27 @@
 """Sample Generate GPT2"""
 
-import os
-import random
-import json
-import copy
-import numpy as np
-import torch
-import torch.nn.functional as F
-import argparse
-import time
-from arguments import get_args
-from megatron.utils import Timers
-from megatron.utils import initialize_distributed
-from megatron.utils import set_random_seed
-from megatron.utils import get_ltor_masks_and_position_ids
-from megatron.utils import load_checkpoint
-from megatron.data_utils import make_tokenizer
-from configure_data import configure_data
-from megatron import mpu
-from megatron.fp16 import FP16_Module
-from megatron.model import GPT2Model
-from megatron.model import DistributedDataParallel as DDP
-from megatron import print_rank_0
+import copy
+import json
+import os
+import time
+
+import torch
+import torch.nn.functional as F
+
+from megatron import get_args
+from megatron import get_tokenizer
+from megatron import mpu
+from megatron import print_rank_0
+from megatron.checkpointing import load_checkpoint
+from megatron.initialize import initialize_megatron
+from megatron.model import GPT2Model
+from megatron.training import get_model
+from megatron.utils import get_ltor_masks_and_position_ids
 
 
 def model_provider():
     """Build the model."""
 
-    args = get_args()
     print_rank_0('building GPT2 model ...')
     model = GPT2Model(num_tokentypes=0, parallel_output=False)
@@ -56,7 +49,7 @@ def get_batch(context_tokens):
     tokenizer = get_tokenizer()
 
     # Move to GPU.
     tokens = context_tokens.view(args.batch_size, -1).contiguous().cuda()
     # Get the attention mask and postition ids.
     attention_mask, _, position_ids = get_ltor_masks_and_position_ids(
         tokens,
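The call above delegates mask construction to megatron.utils.get_ltor_masks_and_position_ids. A minimal standalone sketch of the left-to-right (causal) mask and position ids such a helper is expected to return, ignoring the reset-position and eod-masking options that the real call passes; this is an illustration, not the megatron.utils implementation:

import torch

def simple_ltor_masks_and_position_ids(tokens):
    # tokens: [batch_size, seq_length] LongTensor.
    batch_size, seq_length = tokens.size()
    # Causal mask: position i may only attend to positions <= i.
    attention_mask = torch.tril(
        torch.ones((batch_size, 1, seq_length, seq_length),
                   device=tokens.device))
    # Position ids are simply 0 .. seq_length-1 for every sample.
    position_ids = torch.arange(seq_length, dtype=torch.long,
                                device=tokens.device)
    position_ids = position_ids.unsqueeze(0).expand_as(tokens)
    # The real helper also returns a loss mask; a trivial one stands in here.
    loss_mask = torch.ones_like(tokens, dtype=torch.float)
    return attention_mask, loss_mask, position_ids

The three return values line up with the unpacking in the hunk above, where the loss mask is discarded for generation.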
@@ -80,7 +73,7 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
         # last token of the top-k
         indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
         logits[indices_to_remove] = filter_value
 
     if top_p > 0.0:
         # Cconvert to 1D
         sorted_logits, sorted_indices = torch.sort(
@@ -98,12 +91,12 @@ def top_k_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
         for i in range(sorted_indices.size(0)):
             indices_to_remove = sorted_indices[i][sorted_indices_to_remove[i]]
             logits[i][indices_to_remove] = filter_value
 
     return logits
 
 
 def generate_samples_input_from_file(model):
+    """XXX"""
     args = get_args()
     tokenizer = get_tokenizer()
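The two hunks above leave top_k_logits itself essentially unchanged apart from formatting. For reference, a self-contained sketch of the same top-k / nucleus (top-p) filtering idea on a single 1-D logits vector, written independently of the repository code:

import torch
import torch.nn.functional as F

def filter_logits(logits, top_k=0, top_p=0.0, filter_value=-float('Inf')):
    # logits: 1-D tensor of unnormalized scores over the vocabulary.
    logits = logits.clone()
    if top_k > 0:
        # Mask everything below the k-th largest logit.
        kth_value = torch.topk(logits, top_k)[0][-1]
        logits[logits < kth_value] = filter_value
    if top_p > 0.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)
        # Drop tokens once cumulative probability exceeds top_p,
        # always keeping the single most likely token.
        to_remove = cumulative_probs > top_p
        to_remove[1:] = to_remove[:-1].clone()
        to_remove[0] = False
        logits[sorted_indices[to_remove]] = filter_value
    return logits

# Example: sample one token id from the filtered distribution.
probs = F.softmax(filter_logits(torch.randn(50257), top_k=40, top_p=0.9), dim=-1)
token_id = torch.multinomial(probs, num_samples=1)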
@@ -118,15 +111,15 @@ def generate_samples_input_from_file(model):
     if args.sample_output_file is None:
         sample_output_file = args.sample_input_file + ".out"
         print('could not find `sample-output-file`, setting '
-              'it to {}'.formatsample_output_file())
+              'it to {}'.format(sample_output_file))
     fname_out = open(sample_output_file, "w+")
 
     context_count = 0
     model.eval()
     with torch.no_grad():
         while True:
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
             terminate_runs = 0
 
             if mpu.get_model_parallel_rank() == 0:
                 raw_text = all_raw_text[input_pos]
@@ -148,7 +141,7 @@ def generate_samples_input_from_file(model):
             else:
                 context_tokens = tokenizer.tokenize("EMPTY TEXT")
                 context_length = len(context_tokens)
 
             terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
             torch.distributed.broadcast(terminate_runs_tensor,
                                         mpu.get_model_parallel_src_rank(),
@@ -158,9 +151,8 @@ def generate_samples_input_from_file(model):
             if terminate_runs == 1:
                 return
 
-            start_time = time.time()
             token_stream = get_token_stream(model, [context_tokens])
-            for counter, decode_tokens in enumerate(token_stream):
+            for _, decode_tokens in enumerate(token_stream):
                 decode_tokens, _ = decode_tokens
                 decode_tokens = decode_tokens[0].cpu().numpy().tolist()
@@ -176,24 +168,24 @@ def generate_samples_input_from_file(model):
                     fname_out.write("\n\nMegatron-LM:")
                     fname_out.write(trim_decode_tokens)
                     fname_out.write("\n")
 
             raw_text = None
 
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
             context_count += 1
 
 
 def generate_samples_interactive(model, print_frequency=24):
+    """XXX"""
     args = get_args()
     tokenizer = get_tokenizer()
 
     context_count = 0
     model.eval()
     with torch.no_grad():
         while True:
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
             terminate_runs = 0
 
             if mpu.get_model_parallel_rank() == 0:
                 os.system('clear')
@@ -201,7 +193,7 @@ def generate_samples_interactive(model, print_frequency=24):
                 while not raw_text:
                     print('Prompt should not be empty!')
                     raw_text = input("\nContext prompt (stop to exit) >>> ")
 
                 if "stop" in raw_text:
                     terminate_runs = 1
                 else:
@@ -216,7 +208,7 @@ def generate_samples_interactive(model, print_frequency=24):
             else:
                 context_tokens = tokenizer.tokenize("EMPTY TEXT")
                 context_length = len(context_tokens)
 
             terminate_runs_tensor = torch.cuda.LongTensor([terminate_runs])
             torch.distributed.broadcast(terminate_runs_tensor,
                                         mpu.get_model_parallel_src_rank(),
@@ -226,7 +218,6 @@ def generate_samples_interactive(model, print_frequency=24):
             if terminate_runs == 1:
                 return
 
-            start_time = time.time()
             token_stream = get_token_stream(model, [context_tokens])
             for counter, decode_tokens in enumerate(token_stream):
                 decode_tokens, _ = decode_tokens
@@ -250,20 +241,19 @@ def generate_samples_interactive(model, print_frequency=24):
             raw_text = None
 
             torch.distributed.barrier(group=mpu.get_model_parallel_group())
             context_count += 1
 
             if mpu.get_model_parallel_rank() == 0:
                 input("\nPress any key to continue >>>")
 
 
 def generate_samples_unconditional(model):
+    """XXX"""
     args = get_args()
     tokenizer = get_tokenizer()
 
     num_samples = args.num_samples
     context_tokens = [[tokenizer.eod]
                       for _ in range(args.batch_size)]
-    samples = []
     ctr = 0
     while True:
         start_time = time.time()
@@ -291,6 +281,7 @@ def generate_samples_unconditional(model):
 
 
 def write_and_generate_samples_unconditional(model):
     args = get_args()
     assert args.genfile is not None
     with open(args.genfile, 'w') as f:
@@ -298,8 +289,8 @@ def write_and_generate_samples_unconditional(model):
             f.write(json.dumps(datum) + '\n')
 
 
-def pad_batch(batch, tokenizer, args):
+def pad_batch(batch, pad_id, args):
 
-    pad_id = tokenizer.eod
     context_lengths = []
     for tokens in batch:
         context_length = len(tokens)
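The signature change above means callers now pass the pad token id directly instead of a tokenizer object. The body of pad_batch is collapsed in the diff; a sketch of the padding it presumably performs, with the target length taken as an explicit argument (the repository version reads it from args):

def pad_batch_sketch(batch, pad_id, seq_length):
    # Pad every context up to seq_length with pad_id and remember
    # each context's original length.
    context_lengths = []
    for tokens in batch:
        context_length = len(tokens)
        if context_length < seq_length:
            tokens.extend([pad_id] * (seq_length - context_length))
        context_lengths.append(context_length)
    return batch, context_lengths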
@@ -310,11 +301,12 @@ def pad_batch(batch, tokenizer, args):
 
 
 def get_token_stream(model, context_tokens):
 
     args = get_args()
     tokenizer = get_tokenizer()
 
-    pad_id = tokenizer.eod
     context_tokens, context_lengths = pad_batch(context_tokens,
-                                                tokenizer, args)
+                                                tokenizer.eod, args)
 
     context_tokens_tensor = torch.cuda.LongTensor(context_tokens)
     context_length_tensor = torch.cuda.LongTensor(context_lengths)
@@ -327,12 +319,7 @@ def get_token_stream(model, context_tokens):
                                 group=mpu.get_model_parallel_group())
 
     context_length = context_length_tensor.min().item()
     tokens, attention_mask, position_ids = get_batch(context_tokens_tensor, args)
-    counter = 0
-    org_context_length = context_length
-    layer_past = None
-
     batch_token_iterator = sample_sequence_batch(model, context_tokens_tensor,
                                                  context_length_tensor,
@@ -343,21 +330,22 @@ def get_token_stream(model, context_tokens):
 
 
 def switch(val1, val2, boolean):
 
     boolean = boolean.type_as(val1)
     return (1 - boolean) * val1 + boolean * val2
 
 
 def sample_sequence_batch(model, context_tokens, context_lengths,
                           attention_mask, position_ids,
                           maxlen=None, type_ids=None):
+    """XXX"""
     args = get_args()
     tokenizer = get_tokenizer()
 
     model.eval()
     with torch.no_grad():
         context_length = context_lengths.min().item()
-        eos_id = tokenizer.get_command('eos').Id
+        eos_id = tokenizer.eod
 
         counter = 0
         org_context_length = context_length
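switch above is an element-wise select: wherever boolean is 1 the result takes val2, otherwise val1. A small usage example (the tensors are made up for illustration):

import torch

def switch(val1, val2, boolean):
    boolean = boolean.type_as(val1)
    return (1 - boolean) * val1 + boolean * val2

val1 = torch.tensor([10, 20, 30])
val2 = torch.tensor([1, 2, 3])
mask = torch.tensor([0, 1, 0])
print(switch(val1, val2, mask))  # tensor([10,  2, 30])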
@@ -372,7 +360,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 maxlen = org_context_length + args.out_seq_length
 
         lengths = torch.ones([batch_size]).long().cuda() * maxlen
 
         while context_length <= (maxlen):
             if args.recompute:
@@ -404,7 +392,7 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                                get_key_value=True,
                                tokentype_ids=types2use,
                                forward_method_parallel_output=False)
                 logits = logits[:, -1].view(batch_size, -1).contiguous()
 
             if args.greedy:
                 prev = torch.argmax(logits, dim=-1).view(-1)
@@ -429,7 +417,6 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                 done_token = (prev == eos_id).byte() & started.byte()
                 just_finished = (done_token & ~is_done).bool()
                 lengths[just_finished.view(-1)] = context_length
-                was_done = is_done
                 is_done = is_done | done_token
                 done = torch.all(is_done)
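The block above records, per sample, whether an end-of-sequence token has been produced after that sample's own context has been consumed, and notes the length at which generation finished. A small illustration of the same bookkeeping using bool tensors (the repository code uses byte tensors; the values here are made up):

import torch

eos_id = 2
prev = torch.tensor([2, 7, 2])                 # tokens just sampled
started = torch.tensor([True, True, False])    # past their own context?
is_done = torch.tensor([False, False, False])
lengths = torch.tensor([100, 100, 100])
context_length = 42

done_token = (prev == eos_id) & started   # only sample 0 finishes "for real"
just_finished = done_token & ~is_done     # newly finished at this step
lengths[just_finished] = context_length   # record where it stopped
is_done = is_done | done_token
# lengths -> [42, 100, 100], is_done -> [True, False, False]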
@@ -438,56 +425,59 @@ def sample_sequence_batch(model, context_tokens, context_lengths,
                     break
 
 
 def add_text_generate_args(parser):
-    """Text generate arguments."""
+    """Text generation arguments."""
 
-    group = parser.add_argument_group(title='text generation')
+    group = parser.add_argument_group('Text generation', 'configurations')
 
-    group.add_argument("--temperature", type=float, default=1.0)
-    group.add_argument("--greedy", action='store_true', default=False)
-    group.add_argument("--top_p", type=float, default=0.0)
-    group.add_argument("--top_k", type=int, default=0)
-    group.add_argument("--out-seq-length", type=int, default=1024)
+    group.add_argument("--temperature", type=float, default=1.0,
+                       help='Sampling temperature.')
+    group.add_argument("--greedy", action='store_true', default=False,
+                       help='Use greedy sampling.')
+    group.add_argument("--top_p", type=float, default=0.0,
+                       help='Top p sampling.')
+    group.add_argument("--top_k", type=int, default=0,
+                       help='Top k sampling.')
+    group.add_argument("--out-seq-length", type=int, default=1024,
+                       help='Size of the output generated text.')
     group.add_argument("--sample-input-file", type=str, default=None,
-                       help='get input from file instead of interactive mode, '
-                       'each line is an input')
+                       help='Get input from file instead of interactive mode, '
+                       'each line is an input.')
     group.add_argument("--sample-output-file", type=str, default=None,
-                       help='output file got from --sample-input-file')
+                       help='Output file got from --sample-input-file')
     group.add_argument("--num-samples", type=int, default=0,
-                       help='number of samples to generate unconditionally, '
+                       help='Number of samples to generate unconditionally, '
                        'defaults to 0 and interactive conditional sampling')
     group.add_argument("--genfile", type=str,
-                       help='output file when generating unconditionally')
+                       help='Output file when generating unconditionally')
     group.add_argument("--recompute", action='store_true',
-                       help='during generation recompute all attention '
+                       help='During generation recompute all attention '
                        'instead of using previously computed keys/values.')
 
     return parser
 
 
 def main():
     """Main program."""
 
-    print('Generate Samples')
     initialize_megatron(extra_args_provider=add_text_generate_args,
                         args_defaults={'tokenizer_type': 'GPT2BPETokenizer'})
 
     # Set up model and load checkpoint.
     model = get_model(model_provider)
+    args = get_args()
     if args.load is not None:
         _ = load_checkpoint(model, None, None)
 
-    # generate samples
+    # Generate samples.
     if args.num_samples == 0:
-        args.batch_size = 1
+        assert args.batch_size == 1
         if args.sample_input_file != "":
             generate_samples_input_from_file(model)
         else:
             generate_samples_interactive(model)
     else:
         write_and_generate_samples_unconditional(model)
 
 
 if __name__ == "__main__":
     main()
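With the new arguments in place, unconditional generation could be launched along the following lines (model, tokenizer and checkpoint arguments omitted; the paths are placeholders, not taken from the commit):

python generate_samples.py --load <checkpoint-dir> --num-samples 10 \
    --genfile samples.json --top_k 40 --temperature 0.9 --out-seq-length 256

Leaving --num-samples at its default of 0 instead selects conditional generation: prompts are read from --sample-input-file when it is set, otherwise interactively.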
pretrain_gpt2.py
File mode changed from 100755 to 100644, no content changes.