Unverified Commit a156e203 authored by saberkun's avatar saberkun Committed by GitHub
Browse files

Merged commit includes the following changes: (#7100)

254874613  by hongkuny<hongkuny@google.com>:

    Update glue tasks enum to match directory name

--
254866171  by taylorrobie<taylorrobie@google.com>:

    Internal change

PiperOrigin-RevId: 254874613
parent 240623ac
@@ -40,8 +40,8 @@ flags.DEFINE_string(
"The input data dir. Should contain the .tsv files (or other data files) " "The input data dir. Should contain the .tsv files (or other data files) "
"for the task.") "for the task.")
# GLUE classification task selector.  Values are uppercased to match the
# task directory names on disk (see commit: "Update glue tasks enum to
# match directory name").
flags.DEFINE_enum("classification_task_name", "MNLI",
                  ["COLA", "MNLI", "MRPC", "XNLI"],
                  "The name of the task to train BERT classifier.")
# BERT Squad task specific flags. # BERT Squad task specific flags.
......
@@ -136,14 +136,24 @@ class AdamWeightDecay(tf.keras.optimizers.Adam):
def _resource_apply_dense(self, grad, var):
    """Apply a dense gradient update to `var` with decoupled weight decay.

    The weight-decay op is sequenced *before* the Adam update via
    `tf.control_dependencies`, so the decay is computed from the
    pre-update value of `var`.

    Args:
      grad: Dense gradient tensor for `var`.
      var: The variable being updated.

    Returns:
      The op returned by the parent Adam `_resource_apply_dense`.
    """
    var_dtype = var.dtype.base_dtype
    try:
        # Fast path: reuse the learning rate cached per (device, dtype)
        # when the optimizer has built `apply_cache`.
        # NOTE(review): `apply_cache` is populated elsewhere in this
        # class — not visible in this hunk; confirm its structure.
        lr_t = self.apply_cache[var.device, var_dtype].lr_t
    except AttributeError:
        # No cache on this instance — fall back to the standard
        # decayed-learning-rate table.
        lr_t = self._decayed_lr_t[var_dtype]
    with tf.control_dependencies([self._decay_weights_op(var, lr_t)]):
        return super(AdamWeightDecay, self)._resource_apply_dense(grad, var)
def _resource_apply_sparse(self, grad, var, indices):
    """Apply a sparse gradient update to `var` with decoupled weight decay.

    Mirrors `_resource_apply_dense`: weight decay is sequenced before the
    Adam update so the decay sees the pre-update variable value.

    Args:
      grad: Gradient values for the rows selected by `indices`.
      var: The variable being updated.
      indices: Indices of the rows of `var` that `grad` applies to.

    Returns:
      The op returned by the parent Adam `_resource_apply_sparse`.
    """
    var_dtype = var.dtype.base_dtype
    try:
        # Fast path: per-(device, dtype) cached learning rate, if the
        # optimizer has built `apply_cache`.
        # NOTE(review): cache structure defined outside this hunk —
        # confirm against the rest of the class.
        lr_t = self.apply_cache[var.device, var_dtype].lr_t
    except AttributeError:
        # No cache attribute — use the decayed-learning-rate table.
        lr_t = self._decayed_lr_t[var_dtype]
    with tf.control_dependencies([self._decay_weights_op(var, lr_t)]):
        return super(AdamWeightDecay, self)._resource_apply_sparse(
            grad, var, indices)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment.