Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
chenpangpang
transformers
Commits
4bbb9f2d
"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "a7bf77fc284810483f1e60afe34d1d27ad91ce2e"
Commit
4bbb9f2d
authored
Feb 08, 2019
by
thomwolf
Browse files
log loss - helpers
parent
5d7e8457
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
11 additions
and
10 deletions
+11
-10
examples/run_openai_gpt.py
examples/run_openai_gpt.py
+11
-10
No files found.
examples/run_openai_gpt.py
View file @
4bbb9f2d
...
@@ -100,18 +100,17 @@ def main():
...
@@ -100,18 +100,17 @@ def main():
parser
.
add_argument
(
'--lm_coef'
,
type
=
float
,
default
=
0.5
)
parser
.
add_argument
(
'--lm_coef'
,
type
=
float
,
default
=
0.5
)
parser
.
add_argument
(
'--n_valid'
,
type
=
int
,
default
=
374
)
parser
.
add_argument
(
'--n_valid'
,
type
=
int
,
default
=
374
)
parser
.
add_argument
(
'--server_ip'
,
type
=
str
,
default
=
''
)
parser
.
add_argument
(
'--server_ip'
,
type
=
str
,
default
=
''
,
help
=
"Can be used for distant debugging."
)
parser
.
add_argument
(
'--server_port'
,
type
=
str
,
default
=
''
)
parser
.
add_argument
(
'--server_port'
,
type
=
str
,
default
=
''
,
help
=
"Can be used for distant debugging."
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print
(
args
)
print
(
args
)
# Some distant debugging
if
args
.
server_ip
and
args
.
server_port
:
# See https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
# Distant debugging - see https://code.visualstudio.com/docs/python/debugging#_attach-to-a-local-script
import
ptvsd
import
ptvsd
print
(
"Waiting for debugger attach"
)
print
(
"Waiting for debugger attach"
)
ptvsd
.
enable_attach
(
address
=
(
args
.
server_ip
,
args
.
server_port
),
redirect_output
=
True
)
ptvsd
.
enable_attach
(
address
=
(
args
.
server_ip
,
args
.
server_port
),
redirect_output
=
True
)
ptvsd
.
wait_for_attach
()
ptvsd
.
wait_for_attach
()
random
.
seed
(
args
.
seed
)
random
.
seed
(
args
.
seed
)
np
.
random
.
seed
(
args
.
seed
)
np
.
random
.
seed
(
args
.
seed
)
...
@@ -192,7 +191,8 @@ def main():
...
@@ -192,7 +191,8 @@ def main():
for
_
in
trange
(
int
(
args
.
num_train_epochs
),
desc
=
"Epoch"
):
for
_
in
trange
(
int
(
args
.
num_train_epochs
),
desc
=
"Epoch"
):
tr_loss
=
0
tr_loss
=
0
nb_tr_examples
,
nb_tr_steps
=
0
,
0
nb_tr_examples
,
nb_tr_steps
=
0
,
0
for
step
,
batch
in
enumerate
(
tqdm
(
train_dataloader
,
desc
=
"Iteration"
)):
tqdm_bar
=
tqdm
(
train_dataloader
,
desc
=
"Training"
)
for
step
,
batch
in
enumerate
(
tqdm_bar
):
batch
=
tuple
(
t
.
to
(
device
)
for
t
in
batch
)
batch
=
tuple
(
t
.
to
(
device
)
for
t
in
batch
)
input_ids
,
mc_token_mask
,
lm_labels
,
mc_labels
=
batch
input_ids
,
mc_token_mask
,
lm_labels
,
mc_labels
=
batch
losses
=
model
(
input_ids
,
mc_token_mask
,
lm_labels
,
mc_labels
)
losses
=
model
(
input_ids
,
mc_token_mask
,
lm_labels
,
mc_labels
)
...
@@ -202,6 +202,7 @@ def main():
...
@@ -202,6 +202,7 @@ def main():
tr_loss
+=
loss
.
item
()
tr_loss
+=
loss
.
item
()
nb_tr_examples
+=
input_ids
.
size
(
0
)
nb_tr_examples
+=
input_ids
.
size
(
0
)
nb_tr_steps
+=
1
nb_tr_steps
+=
1
tqdm_bar
.
desc
=
"Training loss: {:e.2}"
.
format
(
tr_loss
/
nb_tr_steps
)
# Save a trained model
# Save a trained model
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
# Only save the model itself
model_to_save
=
model
.
module
if
hasattr
(
model
,
'module'
)
else
model
# Only save the model itself
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment