Commit f96e254e authored by Josh Abramson, committed by Copybara-Service
Browse files

Add `eval_dropout` option for using dropout in trunk at eval time.

PiperOrigin-RevId: 496885301
Change-Id: I42de2dd13784e2b358320349398a3fc88ee0d708
parent 4b726a2e
...@@ -383,6 +383,7 @@ CONFIG = ml_collections.ConfigDict({ ...@@ -383,6 +383,7 @@ CONFIG = ml_collections.ConfigDict({
'subbatch_size': 4, 'subbatch_size': 4,
'use_remat': False, 'use_remat': False,
'zero_init': True, 'zero_init': True,
'eval_dropout': False,
}, },
'heads': { 'heads': {
'distogram': { 'distogram': {
...@@ -616,6 +617,7 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({ ...@@ -616,6 +617,7 @@ CONFIG_MULTIMER = ml_collections.ConfigDict({
'subbatch_size': 4, 'subbatch_size': 4,
'use_remat': False, 'use_remat': False,
'zero_init': True, 'zero_init': True,
'eval_dropout': False,
}, },
'heads': { 'heads': {
'distogram': { 'distogram': {
......
...@@ -76,6 +76,9 @@ def dropout_wrapper(module, ...@@ -76,6 +76,9 @@ def dropout_wrapper(module,
residual = module(input_act, mask, is_training=is_training, **kwargs) residual = module(input_act, mask, is_training=is_training, **kwargs)
dropout_rate = 0.0 if gc.deterministic else module.config.dropout_rate dropout_rate = 0.0 if gc.deterministic else module.config.dropout_rate
# Will override `is_training` to True if want to use dropout.
should_apply_dropout = True if gc.eval_dropout else is_training
if module.config.shared_dropout: if module.config.shared_dropout:
if module.config.orientation == 'per_row': if module.config.orientation == 'per_row':
broadcast_dim = 0 broadcast_dim = 0
...@@ -87,7 +90,7 @@ def dropout_wrapper(module, ...@@ -87,7 +90,7 @@ def dropout_wrapper(module,
residual = apply_dropout(tensor=residual, residual = apply_dropout(tensor=residual,
safe_key=safe_key, safe_key=safe_key,
rate=dropout_rate, rate=dropout_rate,
is_training=is_training, is_training=should_apply_dropout,
broadcast_dim=broadcast_dim) broadcast_dim=broadcast_dim)
new_act = output_act + residual new_act = output_act + residual
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment