Commit e51692c9 authored by Hamish Tomlinson, committed by Copybara-Service

Change softmax to use where and float32.

PiperOrigin-RevId: 519675443
Change-Id: If87e6d16189ddcc03bb8435308d37f5919353107
parent e1d2d53a
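
The diff below replaces additive-bias masking of attention logits (adding 1e9 * (mask - 1.)) with an explicit jnp.where against a large negative constant, and routes the softmax through a float32 path so bfloat16 activations avoid the numerical issues that large negative mask values cause. A minimal sketch of the two styles, not part of the commit (shapes and values are made up for illustration):

import jax
import jax.numpy as jnp

_SOFTMAX_MASK = -1e9

logits = jnp.array([[2.0, 1.0, 0.5]], dtype=jnp.bfloat16)
mask = jnp.array([[1, 1, 0]], dtype=jnp.bool_)  # last key position is masked out

# Old style: fold the mask into an additive bias on the logits.
bias = 1e9 * (mask.astype(jnp.float32) - 1.)
old_weights = jax.nn.softmax(logits.astype(jnp.float32) + bias)

# New style: select masked positions explicitly, then softmax in float32.
masked_logits = jnp.where(mask, logits, _SOFTMAX_MASK)
new_weights = jax.nn.softmax(masked_logits.astype(jnp.float32))
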
@@ -32,6 +32,9 @@ import jax
import jax.numpy as jnp
_SOFTMAX_MASK = -1e9
def softmax_cross_entropy(logits, labels):
"""Computes softmax cross entropy given logits and one-hot class labels."""
loss = -jnp.sum(labels * jax.nn.log_softmax(logits), axis=-1)
@@ -548,14 +551,14 @@ class Attention(hk.Module):
self.global_config = global_config
self.output_dim = output_dim
def __call__(self, q_data, m_data, bias, nonbatched_bias=None):
def __call__(self, q_data, m_data, mask, nonbatched_bias=None):
"""Builds Attention module.
Arguments:
q_data: A tensor of queries, shape [batch_size, N_queries, q_channels].
m_data: A tensor of memories from which the keys and values are
projected, shape [batch_size, N_keys, m_channels].
bias: A bias for the attention, shape [batch_size, N_queries, N_keys].
mask: A mask for the attention, shape [batch_size, N_queries, N_keys].
nonbatched_bias: Shared bias, shape [N_queries, N_keys].
Returns:
@@ -586,10 +589,11 @@ class Attention(hk.Module):
q = jnp.einsum('bqa,ahc->bqhc', q_data, q_weights) * key_dim**(-0.5)
k = jnp.einsum('bka,ahc->bkhc', m_data, k_weights)
v = jnp.einsum('bka,ahc->bkhc', m_data, v_weights)
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k) + bias
logits = jnp.einsum('bqhc,bkhc->bhqk', q, k)
if nonbatched_bias is not None:
logits += jnp.expand_dims(nonbatched_bias, axis=0)
weights = jax.nn.softmax(logits)
logits = jnp.where(mask, logits, _SOFTMAX_MASK)
weights = utils.stable_softmax(logits)
weighted_avg = jnp.einsum('bhqk,bkhc->bqhc', weights, v)
if self.global_config.zero_init:
@@ -686,9 +690,10 @@ class GlobalAttention(hk.Module):
q = jnp.einsum('ba,ahc->bhc', q_avg, q_weights) * key_dim**(-0.5)
k = jnp.einsum('bka,ac->bkc', m_data, k_weights)
bias = (1e9 * (q_mask[:, None, :, 0] - 1.))
logits = jnp.einsum('bhc,bkc->bhk', q, k) + bias
weights = jax.nn.softmax(logits)
mask = q_mask[:, None, :, 0]
logits = jnp.einsum('bhc,bkc->bhk', q, k)
logits = jnp.where(mask, logits, _SOFTMAX_MASK)
weights = utils.stable_softmax(logits)
weighted_avg = jnp.einsum('bhk,bkc->bhc', weights, v)
if self.global_config.zero_init:
@@ -761,8 +766,8 @@ class MSARowAttentionWithPairBias(hk.Module):
assert len(msa_mask.shape) == 2
assert c.orientation == 'per_row'
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
mask = msa_mask[:, None, None, :]
assert len(mask.shape) == 4
msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
@@ -788,7 +793,7 @@ class MSARowAttentionWithPairBias(hk.Module):
msa_act = mapping.inference_subbatch(
attn_mod,
self.global_config.subbatch_size,
batched_args=[msa_act, msa_act, bias],
batched_args=[msa_act, msa_act, mask],
nonbatched_args=[nonbatched_bias],
low_memory=not is_training)
@@ -829,8 +834,8 @@ class MSAColumnAttention(hk.Module):
msa_act = jnp.swapaxes(msa_act, -2, -3)
msa_mask = jnp.swapaxes(msa_mask, -1, -2)
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
mask = msa_mask[:, None, None, :]
assert len(mask.shape) == 4
msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
@@ -841,7 +846,7 @@ class MSAColumnAttention(hk.Module):
msa_act = mapping.inference_subbatch(
attn_mod,
self.global_config.subbatch_size,
batched_args=[msa_act, msa_act, bias],
batched_args=[msa_act, msa_act, mask],
nonbatched_args=[],
low_memory=not is_training)
@@ -884,9 +889,6 @@ class MSAColumnGlobalAttention(hk.Module):
msa_act = jnp.swapaxes(msa_act, -2, -3)
msa_mask = jnp.swapaxes(msa_mask, -1, -2)
bias = (1e9 * (msa_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
msa_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
msa_act)
@@ -941,8 +943,8 @@ class TriangleAttention(hk.Module):
pair_act = jnp.swapaxes(pair_act, -2, -3)
pair_mask = jnp.swapaxes(pair_mask, -1, -2)
bias = (1e9 * (pair_mask - 1.))[:, None, None, :]
assert len(bias.shape) == 4
mask = pair_mask[:, None, None, :]
assert len(mask.shape) == 4
pair_act = common_modules.LayerNorm(
axis=[-1], create_scale=True, create_offset=True, name='query_norm')(
@@ -961,7 +963,7 @@ class TriangleAttention(hk.Module):
pair_act = mapping.inference_subbatch(
attn_mod,
self.global_config.subbatch_size,
batched_args=[pair_act, pair_act, bias],
batched_args=[pair_act, pair_act, mask],
nonbatched_args=[nonbatched_bias],
low_memory=not is_training)
@@ -2171,11 +2173,11 @@ class TemplateEmbedding(hk.Module):
jnp.transpose(template_pair_representation, [1, 2, 0, 3]),
[num_res * num_res, num_templates, num_channels])
bias = (1e9 * (template_mask[None, None, None, :] - 1.))
mask = template_mask[None, None, None, :]
template_pointwise_attention_module = Attention(
self.config.attention, self.global_config, query_num_channels)
nonbatched_args = [bias]
nonbatched_args = [mask]
batched_args = [flat_query, flat_templates]
embedding = mapping.inference_subbatch(
......
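
The call sites above now pass the raw 0/1 mask broadcast to the logit shape (for example msa_mask[:, None, None, :]) instead of a precomputed additive bias. A small standalone shape sketch of that broadcasting, with arbitrary sizes (not part of the commit):

import jax.numpy as jnp

num_seq, num_res, num_heads = 4, 8, 2
msa_mask = jnp.ones((num_seq, num_res)).at[:, -2:].set(0.)   # [N_seq, N_res], 1 = keep
mask = msa_mask[:, None, None, :]                            # [N_seq, 1, 1, N_res]
logits = jnp.zeros((num_seq, num_heads, num_res, num_res))   # [N_seq, heads, N_queries, N_keys]
masked = jnp.where(mask, logits, -1e9)                       # broadcasts over heads and queries
assert masked.shape == logits.shape
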
@@ -26,6 +26,20 @@ import jax.numpy as jnp
import numpy as np
def stable_softmax(logits: jax.Array) -> jax.Array:
"""Numerically stable softmax for (potential) bfloat 16."""
if logits.dtype == jnp.float32:
output = jax.nn.softmax(logits)
elif logits.dtype == jnp.bfloat16:
# Compute the softmax explicitly in float32 to avoid numerical issues
# with large negative values, which arise when masking is done by adding
# large negative numbers to the logits so that masked entries softmax to zero.
output = jax.nn.softmax(logits.astype(jnp.float32)).astype(jnp.bfloat16)
else:
raise ValueError(f'Unexpected input dtype {logits.dtype}')
return output
def bfloat16_creator(next_creator, shape, dtype, init, context):
"""Creates float32 variables when bfloat16 is requested."""
if context.original_dtype == jnp.bfloat16:
......
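
A short usage sketch for the new helper (not from the commit; it assumes stable_softmax lives in alphafold.model.utils, which is what the utils.stable_softmax call sites above suggest): bfloat16 logits are masked with the large negative constant and the softmax itself runs in float32.

import jax.numpy as jnp
from alphafold.model import utils  # assumed location of stable_softmax

logits = jnp.array([3.0, 1.0, -2.0], dtype=jnp.bfloat16)
mask = jnp.array([1, 1, 0], dtype=jnp.bool_)
masked = jnp.where(mask, logits, -1e9)   # same constant as _SOFTMAX_MASK
weights = utils.stable_softmax(masked)   # computed in float32, cast back to bfloat16
assert weights.dtype == jnp.bfloat16
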
@@ -604,7 +604,6 @@
" pbar.set_description(f'Running {model_name}')\n",
"\n",
" cfg = config.model_config(model_name)\n",
" cfg.model.global_config.bfloat16 = False\n",
"\n",
" if model_type_to_use == ModelType.MONOMER:\n",
" cfg.data.eval.num_ensemble = 1\n",
......