Commit b6992b7b in chenpangpang/transformers, authored Aug 31, 2019 by LysandreJik

Applied patch to OpenAI GPT, RoBERTa, TransfoXL, XLM and XLNet

parent bdb4409e
Showing 5 changed files with 27 additions and 41 deletions
pytorch_transformers/modeling_openai.py      +8  -17
pytorch_transformers/modeling_roberta.py     +2  -2
pytorch_transformers/modeling_transfo_xl.py  +3  -6
pytorch_transformers/modeling_xlm.py         +9  -8
pytorch_transformers/modeling_xlnet.py       +5  -8
pytorch_transformers/modeling_openai.py
@@ -249,14 +249,15 @@ class Attention(nn.Module):
         self.c_proj = Conv1D(n_state, nx)
         self.attn_dropout = nn.Dropout(config.attn_pdrop)
         self.resid_dropout = nn.Dropout(config.resid_pdrop)
-        self.pruned_heads = []
+        self.pruned_heads = set()

     def prune_heads(self, heads):
         if len(heads) == 0:
             return
         mask = torch.ones(self.n_head, self.split_size // self.n_head)
+        heads = set(heads) - self.pruned_heads
         for head in heads:
-            head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
+            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()

@@ -267,7 +268,7 @@ class Attention(nn.Module):
         # Update hyper params
         self.split_size = (self.split_size // self.n_head) * (self.n_head - len(heads))
         self.n_head = self.n_head - len(heads)
-        self.pruned_heads.extend(heads)
+        self.pruned_heads = self.pruned_heads.union(heads)

     def _attn(self, q, k, v, head_mask=None):
         w = torch.matmul(q, k)
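Note on the prune_heads change above: storing pruned_heads as a set and rewriting the offset as a generator expression keeps the behaviour of the old filter/lambda version, i.e. an incoming head index is shifted down by the number of already-pruned heads that sat before it. A standalone sketch of that arithmetic (the helper name is hypothetical, not part of this patch):

# Sketch only: the head-index remapping performed inside prune_heads().
def adjust_head_index(head, pruned_heads):
    # Each previously pruned head with a smaller original index shifts the
    # remaining heads one slot to the left in the current weight matrices.
    return head - sum(1 if h < head else 0 for h in pruned_heads)

assert adjust_head_index(3, set()) == 3
assert adjust_head_index(3, {0, 1}) == 1   # two earlier heads already removed
assert adjust_head_index(3, {5, 7}) == 3   # later heads do not shift the index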
@@ -366,10 +367,7 @@ class OpenAIGPTPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_openai_gpt
     base_model_prefix = "transformer"

-    def __init__(self, *inputs, **kwargs):
-        super(OpenAIGPTPreTrainedModel, self).__init__(*inputs, **kwargs)
-
-    def init_weights(self, module):
+    def _init_weights(self, module):
         """ Initialize the weights.
         """
         if isinstance(module, (nn.Linear, nn.Embedding, Conv1D)):

@@ -459,14 +457,7 @@ class OpenAIGPTModel(OpenAIGPTPreTrainedModel):
         self.drop = nn.Dropout(config.embd_pdrop)
         self.h = nn.ModuleList([Block(config.n_ctx, config, scale=True) for _ in range(config.n_layer)])

-        if hasattr(config, "pruned_heads"):
-            pruned_heads = config.pruned_heads.copy().items()
-            config.pruned_heads = {}
-            for layer, heads in pruned_heads:
-                if self.h[int(layer)].attn.n_head == config.n_head:
-                    self.prune_heads({int(layer): list(map(int, heads))})
-
-        self.apply(self.init_weights)
+        self.init_weights()

     def _resize_token_embeddings(self, new_num_tokens):
         self.tokens_embed = self._get_resized_embeddings(self.tokens_embed, new_num_tokens)

@@ -579,7 +570,7 @@ class OpenAIGPTLMHeadModel(OpenAIGPTPreTrainedModel):
         self.transformer = OpenAIGPTModel(config)
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)

-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()

     def tie_weights(self):

@@ -686,7 +677,7 @@ class OpenAIGPTDoubleHeadsModel(OpenAIGPTPreTrainedModel):
         self.lm_head = nn.Linear(config.n_embd, config.vocab_size, bias=False)
         self.multiple_choice_head = SequenceSummary(config)

-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()

     def tie_weights(self):
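The hunks above also drop the per-model `def init_weights(self, module)` definitions and the pruned-heads restore loop from OpenAIGPTModel.__init__, replacing `self.apply(self.init_weights)` with a plain `self.init_weights()` call. The matching base-class side lives in PreTrainedModel (modeling_utils.py) and is not part of this diff; below is a minimal sketch of the calling convention the new code assumes, under the assumption that the base class applies the subclass's `_init_weights` hook to every submodule:

import torch.nn as nn

class PreTrainedModelSketch(nn.Module):
    # Assumed shape of the base-class helper; the real PreTrainedModel.init_weights()
    # may also re-apply config.pruned_heads and tie weights, which is not shown here.
    def _init_weights(self, module):
        # Overridden by OpenAIGPTPreTrainedModel, XLMPreTrainedModel, etc. with
        # model-specific initialization rules.
        pass

    def init_weights(self):
        # Walk every submodule once and let the subclass hook initialize it,
        # replacing the old per-model self.apply(self.init_weights) calls.
        self.apply(self._init_weights)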
pytorch_transformers/modeling_roberta.py
@@ -168,7 +168,7 @@ class RobertaModel(BertModel):
         super(RobertaModel, self).__init__(config)

         self.embeddings = RobertaEmbeddings(config)
-        self.apply(self.init_weights)
+        self.init_weights()

     def forward(self, input_ids, token_type_ids=None, attention_mask=None, position_ids=None, head_mask=None):
         if input_ids[:, 0].sum().item() != 0:

@@ -220,7 +220,7 @@ class RobertaForMaskedLM(BertPreTrainedModel):
         self.roberta = RobertaModel(config)
         self.lm_head = RobertaLMHead(config)

-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()

     def tie_weights(self):
pytorch_transformers/modeling_transfo_xl.py
@@ -853,9 +853,6 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_transfo_xl
     base_model_prefix = "transformer"

-    def __init__(self, *inputs, **kwargs):
-        super(TransfoXLPreTrainedModel, self).__init__(*inputs, **kwargs)
-
     def _init_weight(self, weight):
         if self.config.init == 'uniform':
             nn.init.uniform_(weight, -self.config.init_range, self.config.init_range)

@@ -865,7 +862,7 @@ class TransfoXLPreTrainedModel(PreTrainedModel):
     def _init_bias(self, bias):
         nn.init.constant_(bias, 0.0)

-    def init_weights(self, m):
+    def _init_weights(self, m):
         """ Initialize the weights.
         """
         classname = m.__class__.__name__

@@ -1059,7 +1056,7 @@ class TransfoXLModel(TransfoXLPreTrainedModel):
             self.r_emb = nn.Parameter(torch.FloatTensor(
                     self.n_layer, self.max_klen, self.n_head, self.d_head))

-        self.apply(self.init_weights)
+        self.init_weights()

     def _resize_token_embeddings(self, new_num_tokens):
         return self.word_emb

@@ -1306,7 +1303,7 @@ class TransfoXLLMHeadModel(TransfoXLPreTrainedModel):
         else:
             self.crit = ProjectedAdaptiveLogSoftmax(config.n_token, config.d_embed, config.d_model,
                                                     config.cutoffs, div_val=config.div_val)
-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()

     def tie_weights(self):
pytorch_transformers/modeling_xlm.py
@@ -271,15 +271,16 @@ class MultiHeadAttention(nn.Module):
         self.k_lin = nn.Linear(dim, dim)
         self.v_lin = nn.Linear(dim, dim)
         self.out_lin = nn.Linear(dim, dim)
-        self.pruned_heads = []
+        self.pruned_heads = set()

     def prune_heads(self, heads):
         attention_head_size = self.dim // self.n_heads
         if len(heads) == 0:
             return
         mask = torch.ones(self.n_heads, attention_head_size)
+        heads = set(heads) - self.pruned_heads
         for head in heads:
-            head -= len(list(filter(lambda h: h < head, self.pruned_heads)))
+            head -= sum(1 if h < head else 0 for h in self.pruned_heads)
             mask[head] = 0
         mask = mask.view(-1).contiguous().eq(1)
         index = torch.arange(len(mask))[mask].long()

@@ -291,7 +292,7 @@ class MultiHeadAttention(nn.Module):
         # Update hyper params
         self.n_heads = self.n_heads - len(heads)
         self.dim = attention_head_size * self.n_heads
-        self.pruned_heads.extend(heads)
+        self.pruned_heads = self.pruned_heads.union(heads)

     def forward(self, input, mask, kv=None, cache=None, head_mask=None):
         """

@@ -386,7 +387,7 @@ class XLMPreTrainedModel(PreTrainedModel):
     def __init__(self, *inputs, **kwargs):
         super(XLMPreTrainedModel, self).__init__(*inputs, **kwargs)

-    def init_weights(self, module):
+    def _init_weights(self, module):
         """ Initialize the weights. """
         if isinstance(module, nn.Embedding):
             if self.config is not None and self.config.embed_init_std is not None:

@@ -569,7 +570,7 @@ class XLMModel(XLMPreTrainedModel):
                 if self.attentions[int(layer)].n_heads == config.n_heads:
                     self.prune_heads({int(layer): list(map(int, heads))})

-        self.apply(self.init_weights)
+        self.init_weights()

     def _resize_token_embeddings(self, new_num_tokens):
         self.embeddings = self._get_resized_embeddings(self.embeddings, new_num_tokens)

@@ -781,7 +782,7 @@ class XLMWithLMHeadModel(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.pred_layer = XLMPredLayer(config)

-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()

     def tie_weights(self):

@@ -843,7 +844,7 @@ class XLMForSequenceClassification(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.sequence_summary = SequenceSummary(config)

-        self.apply(self.init_weights)
+        self.init_weights()

     def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
                 attention_mask=None, cache=None, labels=None, head_mask=None):

@@ -921,7 +922,7 @@ class XLMForQuestionAnswering(XLMPreTrainedModel):
         self.transformer = XLMModel(config)
         self.qa_outputs = SQuADHead(config)

-        self.apply(self.init_weights)
+        self.init_weights()

     def forward(self, input_ids, lengths=None, position_ids=None, langs=None, token_type_ids=None,
                 attention_mask=None, cache=None, start_positions=None, end_positions=None,
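As in the OpenAI GPT attention module above, MultiHeadAttention.pruned_heads becomes a set, new requests are filtered with `set(heads) - self.pruned_heads`, and the bookkeeping uses `union()`. One practical effect, illustrated below with plain sets rather than a real XLM model, is that asking to prune the same head twice no longer shifts indices a second time:

# Standalone illustration of the set bookkeeping in prune_heads().
already_pruned = {1}            # state after a first call pruned head 1
requested = [1, 2]              # a second call asks for head 1 again, plus head 2

to_prune = set(requested) - already_pruned
assert to_prune == {2}          # head 1 is skipped instead of being re-pruned

already_pruned = already_pruned.union(to_prune)
assert already_pruned == {1, 2}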
pytorch_transformers/modeling_xlnet.py
@@ -586,10 +586,7 @@ class XLNetPreTrainedModel(PreTrainedModel):
     load_tf_weights = load_tf_weights_in_xlnet
     base_model_prefix = "transformer"

-    def __init__(self, *inputs, **kwargs):
-        super(XLNetPreTrainedModel, self).__init__(*inputs, **kwargs)
-
-    def init_weights(self, module):
+    def _init_weights(self, module):
         """ Initialize the weights.
         """
         if isinstance(module, (nn.Linear, nn.Embedding)):

@@ -736,7 +733,7 @@ class XLNetModel(XLNetPreTrainedModel):
         self.layer = nn.ModuleList([XLNetLayer(config) for _ in range(config.n_layer)])
         self.dropout = nn.Dropout(config.dropout)

-        self.apply(self.init_weights)
+        self.init_weights()

     def _resize_token_embeddings(self, new_num_tokens):
         self.word_embedding = self._get_resized_embeddings(self.word_embedding, new_num_tokens)

@@ -1037,7 +1034,7 @@ class XLNetLMHeadModel(XLNetPreTrainedModel):
         self.transformer = XLNetModel(config)
         self.lm_loss = nn.Linear(config.d_model, config.n_token, bias=True)

-        self.apply(self.init_weights)
+        self.init_weights()
         self.tie_weights()

     def tie_weights(self):

@@ -1114,7 +1111,7 @@ class XLNetForSequenceClassification(XLNetPreTrainedModel):
         self.sequence_summary = SequenceSummary(config)
         self.logits_proj = nn.Linear(config.d_model, config.num_labels)

-        self.apply(self.init_weights)
+        self.init_weights()

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None,

@@ -1216,7 +1213,7 @@ class XLNetForQuestionAnswering(XLNetPreTrainedModel):
         self.end_logits = PoolerEndLogits(config)
         self.answer_class = PoolerAnswerClass(config)

-        self.apply(self.init_weights)
+        self.init_weights()

     def forward(self, input_ids, token_type_ids=None, input_mask=None, attention_mask=None,
                 mems=None, perm_mask=None, target_mapping=None,