Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
ModelZoo
ResNet50_tensorflow
Commits
e77956d6
Commit
e77956d6
authored
Jun 11, 2021
by
A. Unique TensorFlower
Browse files
Internal change
PiperOrigin-RevId: 378911698
parent
bab477df
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
26 additions
and
17 deletions
+26
-17
official/nlp/data/sentence_prediction_dataloader.py
official/nlp/data/sentence_prediction_dataloader.py
+7
-5
official/nlp/data/sentence_prediction_dataloader_test.py
official/nlp/data/sentence_prediction_dataloader_test.py
+9
-6
official/nlp/tasks/sentence_prediction.py
official/nlp/tasks/sentence_prediction.py
+10
-6
No files found.
official/nlp/data/sentence_prediction_dataloader.py
View file @
e77956d6
...
@@ -40,6 +40,7 @@ class SentencePredictionDataConfig(cfg.DataConfig):
...
@@ -40,6 +40,7 @@ class SentencePredictionDataConfig(cfg.DataConfig):
label_type
:
str
=
'int'
label_type
:
str
=
'int'
# Whether to include the example id number.
# Whether to include the example id number.
include_example_id
:
bool
=
False
include_example_id
:
bool
=
False
label_field
:
str
=
'label_ids'
# Maps the key in TfExample to feature name.
# Maps the key in TfExample to feature name.
# E.g 'label_ids' to 'next_sentence_labels'
# E.g 'label_ids' to 'next_sentence_labels'
label_name
:
Optional
[
Tuple
[
str
,
str
]]
=
None
label_name
:
Optional
[
Tuple
[
str
,
str
]]
=
None
...
@@ -53,6 +54,7 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
...
@@ -53,6 +54,7 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
self
.
_params
=
params
self
.
_params
=
params
self
.
_seq_length
=
params
.
seq_length
self
.
_seq_length
=
params
.
seq_length
self
.
_include_example_id
=
params
.
include_example_id
self
.
_include_example_id
=
params
.
include_example_id
self
.
_label_field
=
params
.
label_field
if
params
.
label_name
:
if
params
.
label_name
:
self
.
_label_name_mapping
=
dict
([
params
.
label_name
])
self
.
_label_name_mapping
=
dict
([
params
.
label_name
])
else
:
else
:
...
@@ -65,7 +67,7 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
...
@@ -65,7 +67,7 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
'input_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'input_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'input_mask'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'input_mask'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'segment_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'segment_ids'
:
tf
.
io
.
FixedLenFeature
([
self
.
_seq_length
],
tf
.
int64
),
'
label_
ids'
:
tf
.
io
.
FixedLenFeature
([],
label_type
),
self
.
_
label_
field
:
tf
.
io
.
FixedLenFeature
([],
label_type
),
}
}
if
self
.
_include_example_id
:
if
self
.
_include_example_id
:
name_to_features
[
'example_id'
]
=
tf
.
io
.
FixedLenFeature
([],
tf
.
int64
)
name_to_features
[
'example_id'
]
=
tf
.
io
.
FixedLenFeature
([],
tf
.
int64
)
...
@@ -92,10 +94,10 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
...
@@ -92,10 +94,10 @@ class SentencePredictionDataLoader(data_loader.DataLoader):
if
self
.
_include_example_id
:
if
self
.
_include_example_id
:
x
[
'example_id'
]
=
record
[
'example_id'
]
x
[
'example_id'
]
=
record
[
'example_id'
]
x
[
'
label_
ids'
]
=
record
[
'
label_
ids'
]
x
[
self
.
_
label_
field
]
=
record
[
self
.
_
label_
field
]
if
'
label_
ids'
in
self
.
_label_name_mapping
:
if
self
.
_
label_
field
in
self
.
_label_name_mapping
:
x
[
self
.
_label_name_mapping
[
'
label_
ids'
]]
=
record
[
'
label_
ids'
]
x
[
self
.
_label_name_mapping
[
self
.
_
label_
field
]]
=
record
[
self
.
_
label_
field
]
return
x
return
x
...
@@ -215,7 +217,7 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
...
@@ -215,7 +217,7 @@ class SentencePredictionTextDataLoader(data_loader.DataLoader):
model_inputs
=
self
.
_text_processor
(
segments
)
model_inputs
=
self
.
_text_processor
(
segments
)
if
self
.
_include_example_id
:
if
self
.
_include_example_id
:
model_inputs
[
'example_id'
]
=
record
[
'example_id'
]
model_inputs
[
'example_id'
]
=
record
[
'example_id'
]
model_inputs
[
'
label_
ids'
]
=
record
[
self
.
_label_field
]
model_inputs
[
self
.
_
label_
field
]
=
record
[
self
.
_label_field
]
return
model_inputs
return
model_inputs
def
_decode
(
self
,
record
:
tf
.
Tensor
):
def
_decode
(
self
,
record
:
tf
.
Tensor
):
...
...
official/nlp/data/sentence_prediction_dataloader_test.py
View file @
e77956d6
...
@@ -197,13 +197,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
...
@@ -197,13 +197,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
vocab_file
=
vocab_file_path
)
vocab_file
=
vocab_file_path
)
dataset
=
loader
.
SentencePredictionTextDataLoader
(
data_config
).
load
()
dataset
=
loader
.
SentencePredictionTextDataLoader
(
data_config
).
load
()
features
=
next
(
iter
(
dataset
))
features
=
next
(
iter
(
dataset
))
label_field
=
data_config
.
label_field
self
.
assertCountEqual
(
self
.
assertCountEqual
(
[
'input_word_ids'
,
'input_type_ids'
,
'input_mask'
,
'
label_
ids'
],
[
'input_word_ids'
,
'input_type_ids'
,
'input_mask'
,
label_
field
],
features
.
keys
())
features
.
keys
())
self
.
assertEqual
(
features
[
'input_word_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_word_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_mask'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_mask'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_type_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_type_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'
label_
ids'
].
shape
,
(
batch_size
,))
self
.
assertEqual
(
features
[
label_
field
].
shape
,
(
batch_size
,))
@
parameterized
.
parameters
(
True
,
False
)
@
parameterized
.
parameters
(
True
,
False
)
def
test_python_sentencepiece_preprocessing
(
self
,
use_tfds
):
def
test_python_sentencepiece_preprocessing
(
self
,
use_tfds
):
...
@@ -231,13 +232,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
...
@@ -231,13 +232,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
)
)
dataset
=
loader
.
SentencePredictionTextDataLoader
(
data_config
).
load
()
dataset
=
loader
.
SentencePredictionTextDataLoader
(
data_config
).
load
()
features
=
next
(
iter
(
dataset
))
features
=
next
(
iter
(
dataset
))
label_field
=
data_config
.
label_field
self
.
assertCountEqual
(
self
.
assertCountEqual
(
[
'input_word_ids'
,
'input_type_ids'
,
'input_mask'
,
'
label_
ids'
],
[
'input_word_ids'
,
'input_type_ids'
,
'input_mask'
,
label_
field
],
features
.
keys
())
features
.
keys
())
self
.
assertEqual
(
features
[
'input_word_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_word_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_mask'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_mask'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_type_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_type_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'
label_
ids'
].
shape
,
(
batch_size
,))
self
.
assertEqual
(
features
[
label_
field
].
shape
,
(
batch_size
,))
@
parameterized
.
parameters
(
True
,
False
)
@
parameterized
.
parameters
(
True
,
False
)
def
test_saved_model_preprocessing
(
self
,
use_tfds
):
def
test_saved_model_preprocessing
(
self
,
use_tfds
):
...
@@ -265,13 +267,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
...
@@ -265,13 +267,14 @@ class SentencePredictionTfdsDataLoaderTest(tf.test.TestCase,
)
)
dataset
=
loader
.
SentencePredictionTextDataLoader
(
data_config
).
load
()
dataset
=
loader
.
SentencePredictionTextDataLoader
(
data_config
).
load
()
features
=
next
(
iter
(
dataset
))
features
=
next
(
iter
(
dataset
))
label_field
=
data_config
.
label_field
self
.
assertCountEqual
(
self
.
assertCountEqual
(
[
'input_word_ids'
,
'input_type_ids'
,
'input_mask'
,
'
label_
ids'
],
[
'input_word_ids'
,
'input_type_ids'
,
'input_mask'
,
label_
field
],
features
.
keys
())
features
.
keys
())
self
.
assertEqual
(
features
[
'input_word_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_word_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_mask'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_mask'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_type_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'input_type_ids'
].
shape
,
(
batch_size
,
seq_length
))
self
.
assertEqual
(
features
[
'
label_
ids'
].
shape
,
(
batch_size
,))
self
.
assertEqual
(
features
[
label_
field
].
shape
,
(
batch_size
,))
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
...
...
official/nlp/tasks/sentence_prediction.py
View file @
e77956d6
...
@@ -69,6 +69,10 @@ class SentencePredictionTask(base_task.Task):
...
@@ -69,6 +69,10 @@ class SentencePredictionTask(base_task.Task):
if
params
.
metric_type
not
in
METRIC_TYPES
:
if
params
.
metric_type
not
in
METRIC_TYPES
:
raise
ValueError
(
'Invalid metric_type: {}'
.
format
(
params
.
metric_type
))
raise
ValueError
(
'Invalid metric_type: {}'
.
format
(
params
.
metric_type
))
self
.
metric_type
=
params
.
metric_type
self
.
metric_type
=
params
.
metric_type
if
hasattr
(
params
.
train_data
,
'label_field'
):
self
.
label_field
=
params
.
train_data
.
label_field
else
:
self
.
label_field
=
'label_ids'
def
build_model
(
self
):
def
build_model
(
self
):
if
self
.
task_config
.
hub_module_url
and
self
.
task_config
.
init_checkpoint
:
if
self
.
task_config
.
hub_module_url
and
self
.
task_config
.
init_checkpoint
:
...
@@ -95,7 +99,7 @@ class SentencePredictionTask(base_task.Task):
...
@@ -95,7 +99,7 @@ class SentencePredictionTask(base_task.Task):
use_encoder_pooler
=
self
.
task_config
.
model
.
use_encoder_pooler
)
use_encoder_pooler
=
self
.
task_config
.
model
.
use_encoder_pooler
)
def
build_losses
(
self
,
labels
,
model_outputs
,
aux_losses
=
None
)
->
tf
.
Tensor
:
def
build_losses
(
self
,
labels
,
model_outputs
,
aux_losses
=
None
)
->
tf
.
Tensor
:
label_ids
=
labels
[
'
label_
ids'
]
label_ids
=
labels
[
self
.
label_
field
]
if
self
.
task_config
.
model
.
num_classes
==
1
:
if
self
.
task_config
.
model
.
num_classes
==
1
:
loss
=
tf
.
keras
.
losses
.
mean_squared_error
(
label_ids
,
model_outputs
)
loss
=
tf
.
keras
.
losses
.
mean_squared_error
(
label_ids
,
model_outputs
)
else
:
else
:
...
@@ -121,7 +125,7 @@ class SentencePredictionTask(base_task.Task):
...
@@ -121,7 +125,7 @@ class SentencePredictionTask(base_task.Task):
y
=
tf
.
zeros
((
1
,),
dtype
=
tf
.
float32
)
y
=
tf
.
zeros
((
1
,),
dtype
=
tf
.
float32
)
else
:
else
:
y
=
tf
.
zeros
((
1
,
1
),
dtype
=
tf
.
int32
)
y
=
tf
.
zeros
((
1
,
1
),
dtype
=
tf
.
int32
)
x
[
'
label_
ids'
]
=
y
x
[
self
.
label_
field
]
=
y
return
x
return
x
dataset
=
tf
.
data
.
Dataset
.
range
(
1
)
dataset
=
tf
.
data
.
Dataset
.
range
(
1
)
...
@@ -144,10 +148,10 @@ class SentencePredictionTask(base_task.Task):
...
@@ -144,10 +148,10 @@ class SentencePredictionTask(base_task.Task):
def
process_metrics
(
self
,
metrics
,
labels
,
model_outputs
):
def
process_metrics
(
self
,
metrics
,
labels
,
model_outputs
):
for
metric
in
metrics
:
for
metric
in
metrics
:
metric
.
update_state
(
labels
[
'
label_
ids'
],
model_outputs
)
metric
.
update_state
(
labels
[
self
.
label_
field
],
model_outputs
)
def
process_compiled_metrics
(
self
,
compiled_metrics
,
labels
,
model_outputs
):
def
process_compiled_metrics
(
self
,
compiled_metrics
,
labels
,
model_outputs
):
compiled_metrics
.
update_state
(
labels
,
model_outputs
)
compiled_metrics
.
update_state
(
labels
[
self
.
label_field
]
,
model_outputs
)
def
validation_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
,
metrics
=
None
):
def
validation_step
(
self
,
inputs
,
model
:
tf
.
keras
.
Model
,
metrics
=
None
):
if
self
.
metric_type
==
'accuracy'
:
if
self
.
metric_type
==
'accuracy'
:
...
@@ -163,12 +167,12 @@ class SentencePredictionTask(base_task.Task):
...
@@ -163,12 +167,12 @@ class SentencePredictionTask(base_task.Task):
'sentence_prediction'
:
# Ensure one prediction along batch dimension.
'sentence_prediction'
:
# Ensure one prediction along batch dimension.
tf
.
expand_dims
(
tf
.
math
.
argmax
(
outputs
,
axis
=
1
),
axis
=
1
),
tf
.
expand_dims
(
tf
.
math
.
argmax
(
outputs
,
axis
=
1
),
axis
=
1
),
'labels'
:
'labels'
:
labels
[
'
label_
ids'
],
labels
[
self
.
label_
field
],
})
})
if
self
.
metric_type
==
'pearson_spearman_corr'
:
if
self
.
metric_type
==
'pearson_spearman_corr'
:
logs
.
update
({
logs
.
update
({
'sentence_prediction'
:
outputs
,
'sentence_prediction'
:
outputs
,
'labels'
:
labels
[
'
label_
ids'
],
'labels'
:
labels
[
self
.
label_
field
],
})
})
return
logs
return
logs
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment