OpenDAS / fairscale · Commits

Commit 74181b08
[feat] add Transformer gpipe benchmark

Authored Jul 17, 2020 by Jun Ru Anderson
Committed by Mandeep Singh Baines, Jul 31, 2020
Parent: 0cd65242

Showing 4 changed files with 134 additions and 81 deletions
.isort.cfg                               +1    −1
benchmarks/models/transformerModel.py    +0    −64
benchmarks/transformer.py                +132  −15
setup.cfg                                +1    −1
.isort.cfg

 [settings]
-known_third_party = models,pytest,setuptools,torch,torchtext
+known_third_party = pytest,setuptools,torch,torchtext
benchmarks/models/transformerModel.py  (deleted, 100644 → 0)

The entire file was removed:

# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
import math

import torch
import torch.nn as nn


class TransformerModel(nn.Module):
    def __init__(self, ntoken, ninp=200, nhead=2, nhid=200, nlayers=2, dropout=0.5):
        super(TransformerModel, self).__init__()
        from torch.nn import TransformerEncoder, TransformerEncoderLayer

        self.model_type = "Transformer"
        self.src_mask = None
        self.pos_encoder = PositionalEncoding(ninp, dropout)
        encoder_layers = TransformerEncoderLayer(ninp, nhead, nhid, dropout)
        self.transformer_encoder = TransformerEncoder(encoder_layers, nlayers)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.ninp = ninp
        self.decoder = nn.Linear(ninp, ntoken)
        self.init_weights()

    def _generate_square_subsequent_mask(self, sz):
        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
        mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
        return mask

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, src):
        if self.src_mask is None or self.src_mask.size(0) != len(src):
            device = src.device
            mask = self._generate_square_subsequent_mask(len(src)).to(device)
            self.src_mask = mask

        src = self.encoder(src) * math.sqrt(self.ninp)
        src = self.pos_encoder(src)
        output = self.transformer_encoder(src, self.src_mask)
        output = self.decoder(output)
        return output


class PositionalEncoding(nn.Module):
    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer("pe", pe)

    def forward(self, x):
        x = x + self.pe[: x.size(0), :]
        return self.dropout(x)
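
For reference, the deleted PositionalEncoding (and the PositionalEncodingLayer that replaces it in benchmarks/transformer.py below) fills its pe buffer with the standard sinusoidal encoding PE[pos, 2i] = sin(pos / 10000^(2i/d_model)) and PE[pos, 2i+1] = cos(pos / 10000^(2i/d_model)). A minimal standalone check of that equivalence, not part of this commit:

import math

import torch

# Rebuild the buffer exactly as the module does, for a tiny model size.
d_model, max_len = 8, 4
pe = torch.zeros(max_len, d_model)
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
pe[:, 0::2] = torch.sin(position * div_term)
pe[:, 1::2] = torch.cos(position * div_term)

# Compare against the closed-form sinusoidal formula.
for pos in range(max_len):
    for i in range(0, d_model, 2):
        angle = pos / (10000 ** (i / d_model))
        assert torch.isclose(pe[pos, i], torch.tensor(math.sin(angle)), atol=1e-5)
        assert torch.isclose(pe[pos, i + 1], torch.tensor(math.cos(angle)), atol=1e-5)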
benchmarks/transformer.py
@@ -3,12 +3,83 @@
 import math
 import time

-from models import transformerModel as transformer
 import torch
 import torch.nn as nn
 import torchtext
 from torchtext.data.utils import get_tokenizer

+import fairscale.nn.pipe.pipe as pipe
+
+
+class EmbeddingLayer(nn.Embedding):
+    def __init__(self, ntoken, ninp, initrange):
+        super().__init__(ntoken, ninp)
+        self.ninp = ninp
+        self.weight.data.uniform_(-initrange, initrange)
+
+    def forward(self, src):
+        return super().forward(src) * math.sqrt(self.ninp)
+
+
+class PositionalEncodingLayer(nn.Module):
+    def __init__(self, d_model, dropout=0.1, max_len=5000):
+        super(PositionalEncodingLayer, self).__init__()
+        self.dropout = nn.Dropout(p=dropout)
+
+        pe = torch.zeros(max_len, d_model)
+        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+        pe[:, 0::2] = torch.sin(position * div_term)
+        pe[:, 1::2] = torch.cos(position * div_term)
+        pe = pe.unsqueeze(0).transpose(0, 1)
+        self.register_buffer("pe", pe)
+
+    def forward(self, x):
+        x = x + self.pe[: x.size(0), :]
+        return self.dropout(x)
+
+
+class TransformerDecoderLayer(nn.TransformerEncoderLayer):
+    """Though this class inherits from torch.nn.TransformerEncoderLayer,
+    it functions as a decoder in this model."""
+
+    def __init__(self, ninp, nhead, nhid, droupout):
+        super().__init__(ninp, nhead, nhid, droupout)
+        self.src_mask = None
+
+    def _generate_square_subsequent_mask(self, sz):
+        mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+        mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
+        return mask
+
+    def forward(self, src):
+        if self.src_mask is None or self.src_mask.size(0) != len(src):
+            device = src.device
+            mask = self._generate_square_subsequent_mask(len(src)).to(device)
+            self.src_mask = mask
+
+        return super().forward(src, self.src_mask)
+
+
+class LinearLayer(nn.Linear):
+    def __init__(self, ninp, ntoken, initrange):
+        super().__init__(ninp, ntoken)
+        self.bias.data.zero_()
+        self.weight.data.uniform_(-initrange, initrange)
+
+
+class TransformerLMSequntial(nn.Sequential):
+    """A small language model based on the design of GPT-2, using nn.Sequential
+    for compatibility with Pipe."""
+
+    def __init__(self, ntokens, ninp, nhead, nhid, dropout, initrange):
+        super(TransformerLMSequntial, self).__init__(
+            EmbeddingLayer(ntokens, ninp, initrange),
+            PositionalEncodingLayer(ninp, dropout),
+            TransformerDecoderLayer(ninp, nhead, nhid, dropout),
+            LinearLayer(ninp, ntokens, initrange),
+        )


 def get_data(device):
     TEXT = torchtext.data.Field(
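
The docstring on TransformerDecoderLayer is accurate: the class reuses nn.TransformerEncoderLayer but feeds it a causal mask, so each position can only attend to itself and earlier positions. A small illustration (not part of this commit) of the mask it builds for a sequence of length 3:

import torch

sz = 3
# Same construction as _generate_square_subsequent_mask above.
mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
mask = mask.float().masked_fill(mask == 0, float("-inf")).masked_fill(mask == 1, float(0.0))
print(mask)
# tensor([[0., -inf, -inf],
#         [0., 0., -inf],
#         [0., 0., 0.]])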
@@ -43,14 +114,16 @@ def get_batch(source, i, bptt):
 def make_model(device, ntokens):
-    emsize = 50  # embedding dimension
+    ninp = 50  # embedding dimension
     nhid = 50  # the dimension of the feedforward network model in nn.TransformerEncoder
     nlayers = 1  # the number of nn.TransformerEncoderLayer in nn.TransformerEncoder
     nhead = 2  # the number of heads in the multiheadattention models
-    dropout = 0.2  # the dropout value
-    model = transformer.TransformerModel(ntokens, emsize, nhead, nhid, nlayers, dropout).to(device)
+    dropout = 0
+    initrange = 0.1
+    model = TransformerLMSequntial(ntokens, ninp, nhead, nhid, dropout, initrange).to(device)

     criterion = nn.CrossEntropyLoss()
-    lr = 5.0  # learning rate
+    lr = 1.0  # learning rate
     optimizer = torch.optim.SGD(model.parameters(), lr=lr)

     return model, criterion, optimizer
@@ -64,9 +137,12 @@ def train(train_data, model, criterion, optimizer, bptt, ntokens):
         data, targets = get_batch(train_data, i, bptt)
         optimizer.zero_grad()
         output = model(data)
+        output = output.to(targets.device)
         loss = criterion(output.view(-1, ntokens), targets)
         loss.backward()
-        torch.nn.utils.clip_grad_norm_(model.parameters(), 0.5)
+        torch.nn.utils.clip_grad_value_(model.parameters(), 0.05)
         optimizer.step()

         total_loss += loss.item()
@@ -75,8 +151,9 @@ def train(train_data, model, criterion, optimizer, bptt, ntokens):
             cur_loss = total_loss / log_interval
             elapsed = time.time() - start_time
             print(
-                "| {:5d}/{:5d} batches | ms/batch {:5.2f} | loss {:5.2f} | ppl {:8.2f}".format(
-                    batch, len(train_data) // bptt, elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss),
+                "| {:5d}/{:5d} batches | ms/batch {:5.2f} | "
+                "loss {:5.2f} | ppl {:8.2f}".format(
+                    batch, len(train_data) // bptt, elapsed * 1000 / log_interval, cur_loss, math.exp(cur_loss)
                 )
             )
             total_loss = 0
@@ -90,6 +167,7 @@ def evaluate(eval_model, data_source, criterion, bptt, ntokens):
         for i in range(0, data_source.size(0) - 1, bptt):
             data, targets = get_batch(data_source, i, bptt)
             output = eval_model(data)
+            output = output.to(targets.device)
             output_flat = output.view(-1, ntokens)
             total_loss += len(data) * criterion(output_flat, targets).item()
     return total_loss / (len(data_source) - 1)
@@ -112,8 +190,9 @@ def benchmark_language_model(train_data, val_data, test_data, model, criterion,
         val_loss = evaluate(model, val_data, criterion, bptt, ntokens)
         print("-" * 89)
         print(
-            "| end of epoch {:1d} | time: {:5.2f}s | valid loss {:5.2f} | "
-            "valid ppl {:8.2f}".format(epoch, (time.time() - epoch_start_time), val_loss, math.exp(val_loss))
+            "| end of epoch {:1d} | time: {:5.2f}s | valid loss {:5.2f} ".format(
+                epoch, (time.time() - epoch_start_time), val_loss
+            )
         )
         print("-" * 89)
@@ -124,16 +203,54 @@
     test_loss = evaluate(model, test_data, criterion, bptt, ntokens)
     print("=" * 89)
     print(
-        "| end of training | test loss {:5.2f} | test ppl {:8.2f}\n| time: {:5.2f}s | words: {:3d} | wps: {:5.2f}".format(
-            test_loss, math.exp(test_loss), elapsed_time, nwords, wps
+        "| end of training | test loss {:5.2f} \n| time: {:5.2f}s | words: {:3d} | wps: {:5.2f}".format(
+            test_loss, elapsed_time, nwords, wps
         )
     )
     print("=" * 89)
+
+    if len(model.balance) == 4:
+        # Assert that words per second is within 3 standard deviations of the average
+        # of five golden runs
+        assert wps > 19779.5 - (3 * 167.81)
+
+        print("Peak allocated bytes on cuda:0: {:1d}".format(torch.cuda.memory_stats(0)["allocated_bytes.all.peak"]))
+        print("Peak allocated bytes on cuda:1: {:1d}".format(torch.cuda.memory_stats(1)["allocated_bytes.all.peak"]))
+        print("Peak allocated bytes on cuda:2: {:1d}".format(torch.cuda.memory_stats(2)["allocated_bytes.all.peak"]))
+        print("Peak allocated bytes on cuda:3: {:1d}".format(torch.cuda.memory_stats(3)["allocated_bytes.all.peak"]))
+
+        # Assert that memory usage on each GPU is within 10% of the golden run.
+        # Right-hand side is golden-run KB * 1024 (KB-to-bytes conversion) * 110%.
+        assert torch.cuda.memory_stats(0)["allocated_bytes.all.peak"] < 346094 * 1024 * 1.1
+        assert torch.cuda.memory_stats(1)["allocated_bytes.all.peak"] < 1251 * 1024 * 1.1
+        assert torch.cuda.memory_stats(2)["allocated_bytes.all.peak"] < 2595 * 1024 * 1.1
+        assert torch.cuda.memory_stats(3)["allocated_bytes.all.peak"] < 174784 * 1024 * 1.1
+        print("No regression detected")
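
Spelled out, the golden-run guards above amount to two kinds of thresholds; a quick sketch of the arithmetic (illustration only, values copied from the asserts):

# Throughput floor: mean of five golden runs minus three standard deviations.
wps_floor = 19779.5 - 3 * 167.81      # ≈ 19276 words per second
# Memory ceilings: golden-run peak in KB, converted to bytes, with 10% headroom.
cuda0_cap = 346094 * 1024 * 1.1       # ≈ 3.9e8 bytes allowed on cuda:0
cuda3_cap = 174784 * 1024 * 1.1       # ≈ 2.0e8 bytes allowed on cuda:3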
+def generate_balance(num_devices, num_layers):
+    balance = []
+    layers_assigned = 0
+    for i in range(num_devices):
+        x = (num_layers - layers_assigned) / (num_devices - i)
+        if x.is_integer():
+            balance.append(int(x))
+            layers_assigned += x
+        else:
+            balance.append(math.ceil(x))
+            layers_assigned += math.ceil(x)
+    return balance
+
+
 if __name__ == "__main__":
     assert torch.cuda.is_available()
     num_devices = torch.cuda.device_count()
     assert num_devices > 0
     torch.manual_seed(0)
     device = torch.device("cuda")
     ntokens, train_data, val_data, test_data = get_data(device)
     model, criterion, optimizer = make_model(device, ntokens)
-    benchmark_language_model(train_data, val_data, test_data, model, criterion, optimizer, ntokens)
+    balance = generate_balance(min(num_devices, 4), len(model))
+    p = pipe.Pipe(model, balance)
+    benchmark_language_model(train_data, val_data, test_data, p, criterion, optimizer, ntokens)
+    del p
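
generate_balance splits the len(model) top-level layers of the nn.Sequential as evenly as possible across at most four devices, and pipe.Pipe consumes that balance list, mirroring the __main__ block above. A hedged usage sketch; the ntokens value is made up for illustration, and the Pipe call itself needs the multi-GPU setup the benchmark asserts on:

ntokens = 1000  # hypothetical vocabulary size, for illustration only
model = TransformerLMSequntial(ntokens, ninp=50, nhead=2, nhid=50, dropout=0, initrange=0.1)

print(len(model))                       # 4 top-level layers in the nn.Sequential
print(generate_balance(4, len(model)))  # [1, 1, 1, 1] -> one layer per device
print(generate_balance(3, len(model)))  # [2, 1, 1]    -> ceil() front-loads the remainder
print(generate_balance(2, len(model)))  # [2, 2]

# On a multi-GPU machine the benchmark then wraps the model, as in __main__ above:
# p = pipe.Pipe(model, generate_balance(min(torch.cuda.device_count(), 4), len(model)))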
setup.cfg

@@ -47,7 +47,7 @@ use_parentheses=True
 skip_glob = build/*,stubs/*
 # Don't split "import" and "from".
 force_sort_within_sections = true
-known_third_party = models,pytest,setuptools,torch,torchtext
+known_third_party = pytest,setuptools,torch,torchtext

 # -----------------------------------------------------------------------------
 # mypy