Unverified commit 8bba5eeb authored by Ming-Xu Huang, committed by GitHub

[JAX] Support various implementations of RoPE. (#655)



Support various implementations of RoPE and fix a coordinate representation bug
Signed-off-by: Ming Huang <mingh@nvidia.com>
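
For reference, the two grouping methods added here differ only in how coordinates are paired before the rotation is applied. Below is a minimal standalone sketch (not part of the diff; the function names are illustrative and min_timescale is fixed at 1) of the 'alternate' convention, which pairs index i with i + d/2, and the 'consecutive' convention, which pairs index i with i + 1:

import jax.numpy as jnp

def rope_alternate(x, pos, max_timescale=10000):
    # Pair coordinate i with i + d/2 (half-split rotation).
    d = x.shape[-1]
    freqs = 1.0 / (max_timescale ** (2 * jnp.arange(d // 2) / d))
    angles = pos[..., None] * freqs                     # [batch, seq, d/2]
    sin, cos = jnp.sin(angles), jnp.cos(angles)
    x1, x2 = jnp.split(x, 2, axis=-1)
    return jnp.concatenate([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)

def rope_consecutive(x, pos, max_timescale=10000):
    # Pair coordinate i with i + 1 (interleaved rotation).
    d = x.shape[-1]
    freqs = 1.0 / (max_timescale ** (2 * jnp.arange(d // 2) / d))
    angles = pos[..., None] * freqs
    sin, cos = jnp.sin(angles), jnp.cos(angles)
    x1, x2 = x[..., 0::2], x[..., 1::2]                 # even/odd coordinates form the pairs
    out = jnp.stack([x1 * cos - x2 * sin, x2 * cos + x1 * sin], axis=-1)
    return out.reshape(x.shape)

x = jnp.ones((2, 16, 64))                               # [batch, seq, hidden], single head for simplicity
pos = jnp.arange(16)[None, :]                           # [batch, seq]
print(rope_alternate(x, pos).shape, rope_consecutive(x, pos).shape)   # both (2, 16, 64)
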
parent 82bc797f
......@@ -100,6 +100,7 @@ _KEY_OF_SCALE_ATTN_LOGITS = "scale_attn_logits"
_KEY_OF_NUM_HEADS = 'num_attention_heads'
_KEY_OF_NUM_GQA_GROUPS = 'num_gqa_groups'
_KEY_OF_ENABLE_ROPE = "enable_rotary_pos_emb"
_KEY_OF_ROPE_GROUP_METHOD = "rotary_pos_emb_group_method"
BASE_ATTRS = {
_KEY_OF_TRANSPOSE_BS: True,
......@@ -162,13 +163,29 @@ ATTRS = [{
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "consecutive"
}, {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "consecutive"
}, {
_KEY_OF_TRANSPOSE_BS: False,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "alternate"
}, {
_KEY_OF_TRANSPOSE_BS: True,
_KEY_OF_LAYERNORM_TYPE: 'layernorm',
_KEY_OF_DROPOUT_RATE: 0.0,
_KEY_OF_FUSE_MLP_WI: True,
_KEY_OF_ENABLE_ROPE: True,
_KEY_OF_ROPE_GROUP_METHOD: "alternate"
}]
ATTRS = [{**BASE_ATTRS, **attr} for attr in ATTRS]
......@@ -591,7 +608,7 @@ class TestDecoderLayer:
@pytest.mark.parametrize('attrs', ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs):
FP8Helper.finalize() # Ensure FP8 disabled.
self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=2e-04)
self.forward_backward_runner(data_shape, dtype, attrs, rtol=1e-05, atol=3e-04)
@pytest.mark.skipif(not is_fp8_supported, reason=reason)
@pytest.mark.parametrize('data_shape', DATA_SHAPE)
......
......@@ -776,40 +776,74 @@ class MultiHeadAttnAttr:
ZERO_CEN = 'zero_centered_gamma'
NUM_ATTN_HEADS = 'num_attention_heads'
NUM_GQA_GROUPS = 'num_gqa_groups'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding'
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'padding'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
ATTN_MASK_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'consecutive',
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal'
}, {
USE_BIAS: True,
LN_TYPE: 'rmsnorm',
ZERO_CEN: False,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'alternate',
NUM_ATTN_HEADS: 8,
NUM_GQA_GROUPS: 4,
ATTN_MASK_TYPE: 'causal'
......@@ -839,6 +873,8 @@ class TestMultiHeadAttn(TestLayer):
input_layernorm = False
return_layernorm_output = False
attn_mask_type = attrs[MultiHeadAttnAttr.ATTN_MASK_TYPE]
enable_rotary_pos_emb = attrs[MultiHeadAttnAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[MultiHeadAttnAttr.ROPE_GROUP_METHOD]
fuse_qkv_params = True
transpose_batch_sequence = True
scale_attn_logits = False
......@@ -859,6 +895,8 @@ class TestMultiHeadAttn(TestLayer):
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
attn_mask_type=attn_mask_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
......@@ -878,6 +916,8 @@ class TestMultiHeadAttn(TestLayer):
return_layernorm_output=return_layernorm_output,
input_layernorm=input_layernorm,
attn_mask_type=attn_mask_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
fuse_qkv_params=fuse_qkv_params,
transpose_batch_sequence=transpose_batch_sequence,
scale_attn_logits=scale_attn_logits,
......@@ -920,6 +960,7 @@ class TransformerLayerAttr:
ZERO_CEN = 'zero_centered_gamma'
TRANSPOSE_BS = 'transpose_batch_sequence'
ENABLE_ROPE = 'enable_rotary_pos_emb'
ROPE_GROUP_METHOD = 'rotary_pos_emb_group_method'
ATTRS = [{
USE_BIAS: True,
LN_TYPE: 'layernorm',
......@@ -927,6 +968,7 @@ class TransformerLayerAttr:
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -935,6 +977,7 @@ class TransformerLayerAttr:
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -943,6 +986,7 @@ class TransformerLayerAttr:
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -951,6 +995,7 @@ class TransformerLayerAttr:
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -959,6 +1004,7 @@ class TransformerLayerAttr:
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -967,6 +1013,7 @@ class TransformerLayerAttr:
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -975,6 +1022,7 @@ class TransformerLayerAttr:
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -983,6 +1031,7 @@ class TransformerLayerAttr:
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -991,6 +1040,7 @@ class TransformerLayerAttr:
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -999,6 +1049,7 @@ class TransformerLayerAttr:
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -1007,6 +1058,7 @@ class TransformerLayerAttr:
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -1015,6 +1067,7 @@ class TransformerLayerAttr:
ACTIVATION: ('relu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -1023,6 +1076,7 @@ class TransformerLayerAttr:
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -1031,6 +1085,7 @@ class TransformerLayerAttr:
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -1039,6 +1094,7 @@ class TransformerLayerAttr:
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -1047,6 +1103,7 @@ class TransformerLayerAttr:
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -1055,6 +1112,7 @@ class TransformerLayerAttr:
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -1063,6 +1121,7 @@ class TransformerLayerAttr:
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -1071,6 +1130,7 @@ class TransformerLayerAttr:
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: True
}, {
USE_BIAS: True,
......@@ -1079,6 +1139,25 @@ class TransformerLayerAttr:
ACTIVATION: ('gelu', 'linear'),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: False,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'alternate',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
LN_TYPE: 'layernorm',
ZERO_CEN: True,
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'alternate',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -1087,6 +1166,7 @@ class TransformerLayerAttr:
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.ENCODER,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}, {
USE_BIAS: True,
......@@ -1095,6 +1175,7 @@ class TransformerLayerAttr:
ACTIVATION: ('gelu',),
LYR_TYPE: TransformerLayerType.DECODER,
ENABLE_ROPE: True,
ROPE_GROUP_METHOD: 'consecutive',
TRANSPOSE_BS: False
}]
......@@ -1123,6 +1204,7 @@ class TestTransformer(TestLayer):
bias_init = WeightInit.Constant(0.0)
layer_type = attrs[TransformerLayerAttr.LYR_TYPE]
enable_rotary_pos_emb = attrs[TransformerLayerAttr.ENABLE_ROPE]
rotary_pos_emb_group_method = attrs[TransformerLayerAttr.ROPE_GROUP_METHOD]
enable_relative_embedding = True
relative_embedding = pax_fiddle.Config(RelativePositionBiases,
dtype=dtype,
......@@ -1160,6 +1242,7 @@ class TestTransformer(TestLayer):
layer_type=layer_type,
enable_relative_embedding=enable_relative_embedding,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
relative_embedding=relative_embedding,
drop_path=drop_path,
transpose_batch_sequence=transpose_batch_sequence)
......@@ -1182,6 +1265,7 @@ class TestTransformer(TestLayer):
"bias", bias_init),
layer_type=layer_type,
enable_rotary_pos_emb=enable_rotary_pos_emb,
rotary_pos_emb_group_method=rotary_pos_emb_group_method,
enable_relative_embedding=enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
drop_path=drop_path,
......
......@@ -340,7 +340,7 @@ class MlpBlock(nn.Module):
return output
def apply_rotary_pos_emb(
def apply_rotary_pos_emb_alternate(
inputs: jnp.ndarray,
position: jnp.ndarray,
min_timescale: int = 1,
......@@ -363,6 +363,41 @@ def apply_rotary_pos_emb(
return jnp.concatenate([first_part, second_part], axis=-1)
def apply_rotary_pos_emb_consecutive(
inputs: jnp.ndarray,
position: jnp.ndarray,
min_timescale: int = 1,
max_timescale: int = 10000,
):
embedding_dim = inputs.shape[-1]
half_embedding_dim = embedding_dim // 2
fraction = 2 * jnp.arange(0, half_embedding_dim) / embedding_dim
inputs_shifted_left = jnp.concatenate([inputs[..., 1:], inputs[..., :1]], axis=-1)
inputs_shifted_right = jnp.concatenate([inputs[..., -1:], inputs[..., :-1]], axis=-1)
inputs_shifted = jax.lax.select(
jnp.tile(
jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2),
inputs.shape[:-1] + (1,),
),
inputs_shifted_right,
inputs_shifted_left,
)
fraction = jnp.repeat(fraction, 2)
timescale = min_timescale * (max_timescale / min_timescale)**fraction
position = jnp.expand_dims(position, axis=tuple(range(2, inputs.ndim)))
sinusoid_inp = position / timescale
sin = jnp.sin(sinusoid_inp)
cos = jnp.cos(sinusoid_inp)
sign = jnp.sign(jnp.mod(jnp.arange(embedding_dim, dtype=jnp.int32), 2) - 0.5)
outputs = inputs * cos + inputs_shifted * sin * sign
return outputs
dynamic_vector_slice_in_dim = vmap(lax.dynamic_slice_in_dim, in_axes=(None, 0, None, None))
......@@ -392,6 +427,7 @@ class MultiHeadAttention(nn.Module):
scale_attn_logits: bool = False
scaled_query_init: bool = True
enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = 'consecutive'
fuse_qkv: bool = True
def __post_init__(self):
......@@ -512,6 +548,11 @@ class MultiHeadAttention(nn.Module):
position = jnp.expand_dims(jnp.arange(query.shape[seq_dim]), axis=batch_dim)
if self.rotary_pos_emb_group_method == 'alternate':
apply_rotary_pos_emb = apply_rotary_pos_emb_alternate
else:
apply_rotary_pos_emb = apply_rotary_pos_emb_consecutive
query = apply_rotary_pos_emb(query, position)
key = apply_rotary_pos_emb(key, position)
......@@ -836,6 +877,7 @@ class EncoderLayer(nn.Module):
output_layernorm: bool = False
drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = 'consecutive'
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
......@@ -889,6 +931,7 @@ class EncoderLayer(nn.Module):
scaled_query_init=self.scaled_query_init,
fuse_qkv=self.fuse_qkv_params,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
name='attention')(x,
x,
encoder_mask,
......@@ -958,6 +1001,7 @@ class DecoderLayer(nn.Module):
zero_centered_gamma: bool = False
drop_path: float = 0.0
enable_rotary_pos_emb: bool = False
rotary_pos_emb_group_method: str = 'consecutive'
fuse_qkv_params: bool = True
fuse_mlp_wi: bool = False
......@@ -1018,6 +1062,7 @@ class DecoderLayer(nn.Module):
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
name='self_attention')(x,
x,
......@@ -1052,6 +1097,7 @@ class DecoderLayer(nn.Module):
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv=self.fuse_qkv_params,
name='encoder_decoder_attention')(y,
encoded,
......
......@@ -541,19 +541,22 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method
return x
def rotary_pos_emb(x: Array, windows: Tuple[int, int], transpose_batch_sequence: bool):
def rotary_pos_emb(x: Array,
windows: Tuple[int, int],
transpose_batch_sequence: bool,
group_method: str = 'consecutive'):
"""
Rotary Positional Embedding
x should be of shape
[Batch, Seqlen, ..., Hidden] if transpose_batch_sequence is False, or
[Seqlen, Batch, ..., Hidden] if transpose_batch_sequence is True.
[Batch, Seqlen, ..., Heads, Hidden] if transpose_batch_sequence is False, or
[Seqlen, Batch, ..., Heads, Hidden] if transpose_batch_sequence is True.
"""
embed_dim = x.shape[-1]
half_embed_dim = embed_dim // 2
hidden_dim = x.shape[-1]
half_hidden_dim = hidden_dim // 2
min_window = windows[0]
max_window = windows[1]
fraction = 2 * jnp.arange(0, half_embed_dim) / embed_dim
fraction = 2 * jnp.arange(0, half_hidden_dim) / hidden_dim
time_scales = min_window * (max_window / min_window)**fraction
time_scales = jnp.expand_dims(time_scales, axis=tuple(range(x.ndim - 1)))
......@@ -563,15 +566,55 @@ def rotary_pos_emb(x: Array, windows: Tuple[int, int], transpose_batch_sequence:
positions = jnp.expand_dims(jnp.arange(x.shape[seq_dim]), axis=batch_dim)
positions = jnp.expand_dims(positions, axis=tuple(range(2, x.ndim)))
sinusoidal_positions = positions / time_scales
sin = jnp.sin(sinusoidal_positions)
cos = jnp.cos(sinusoidal_positions)
def generate_sin_cos(timescales):
sinusoidal_positions = positions / timescales
sin = jnp.sin(sinusoidal_positions)
cos = jnp.cos(sinusoidal_positions)
return sin, cos
x1, x2 = jnp.split(x, 2, axis=-1)
part_1 = (x1 * cos - x2 * sin).astype(x.dtype)
part_2 = (x2 * cos + x1 * sin).astype(x.dtype)
def alternate_impl():
sin, cos = generate_sin_cos(time_scales)
return jnp.concatenate([part_1, part_2], axis=-1)
x1, x2 = jnp.split(x, 2, axis=-1)
part_1 = (x1 * cos - x2 * sin).astype(x.dtype)
part_2 = (x2 * cos + x1 * sin).astype(x.dtype)
output = jnp.concatenate([part_1, part_2], axis=-1)
return output
def consecutive_impl():
sin, cos = generate_sin_cos(jnp.repeat(time_scales, 2, axis=-1))
x_shifted_left = jnp.roll(x, -1, axis=-1)
x_shifted_right = jnp.roll(x, 1, axis=-1)
x_shifted = jax.lax.select(
jnp.tile(
jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2),
x.shape[:-1] + (1,),
),
x_shifted_right,
x_shifted_left,
)
sign = jnp.sign(jnp.mod(jnp.arange(hidden_dim, dtype=jnp.int32), 2) - 0.5)
output = x * cos + x_shifted * sin * sign
output = output.astype(x.dtype)
return output
def canonicalize_group_method(gm):
canonicalized_gm = gm.lower().strip().replace('-', '').replace('_', '')
assert canonicalized_gm in ['consecutive', 'alternate'], \
f"Invalid relative positional embedding group method. " \
f"Expect to be in []'alternate' or 'consecutive'], but got {gm}."
return canonicalized_gm
group_method = canonicalize_group_method(group_method)
if group_method == 'alternate':
return alternate_impl()
return consecutive_impl()
class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
......@@ -640,6 +683,10 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
Indicate the min and max time-scales of rotary position embedding,
only used when :attr:`enable_rotary_pos_emb=True`
rotary_pos_emb_group_method: str, default = 'consecutive'
Indicate the method to couple the coordinates into rotation pairs. It should be one of
['consecutive', 'alternate']. 'alternate' pairs index :math:`i` with :math:`i + d/2`,
where :math:`d` is the hidden dimension, while 'consecutive' pairs index :math:`i` with :math:`i + 1`.
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism to operations except dot.
num_heads: int, default = None
......@@ -693,6 +740,7 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
dtype: DType = jnp.float32
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
......@@ -942,9 +990,14 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods
else:
assert qkv_layout == QKVLayout.BSHD_BSHD_BSHD
# No changes to memory layout, should trigger a bitcast only (ideally no perf impact)
query = query.reshape((*query.shape[:2], self.num_attention_heads, self.head_dim))
key = key.reshape((*key.shape[:2], self.num_gqa_groups, self.head_dim))
query = rotary_pos_emb(query, self.rotary_pos_emb_windows,
self.transpose_batch_sequence)
key = rotary_pos_emb(key, self.rotary_pos_emb_windows, self.transpose_batch_sequence)
self.transpose_batch_sequence, self.rotary_pos_emb_group_method)
key = rotary_pos_emb(key, self.rotary_pos_emb_windows, self.transpose_batch_sequence,
self.rotary_pos_emb_group_method)
qkv_layout = QKVLayout.BSHD_BSHD_BSHD
if qkv_layout == QKVLayout.BSHD_BSHD_BSHD:
......@@ -1269,6 +1322,10 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
rotary_pos_emb_windows: Tuple[int, int], default = (1, 10000)
Indicate the min and max time-scales of rotary position embedding,
only used when :attr:`enable_rotary_pos_emb=True`
rotary_pos_emb_group_method: str, default = 'consecutive'
Indicate the method to couple the coordinates into rotation pairs. It should be one of
['consecutive', 'alternate']. 'alternate' pairs index :math:`i` with :math:`i + d/2`,
where :math:`d` is the hidden dimension, while 'consecutive' pairs index :math:`i` with :math:`i + 1`.
enable_sequence_parallel: bool, default = False
Whether to enable sequence parallelism to operations except dot.
......@@ -1323,6 +1380,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
relative_embedding: nn.Module = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
dtype: DType = jnp.float32
drop_path: float = 0.0
fuse_qkv_params: bool = True
......@@ -1464,6 +1522,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
attn_bias_type=self.self_attn_bias_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv_params=self.fuse_qkv_params,
kernel_init=self.mha_kernel_init,
use_bias=self.use_bias,
......@@ -1530,6 +1589,7 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods
attn_bias_type='no_bias',
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
float32_logits=self.float32_attention_logits,
scale_attn_logits=self.scale_attn_logits,
scaled_query_init=self.scaled_query_init,
......
......@@ -134,6 +134,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
bias_init: WeightInit = WeightInit.Constant(0.0)
attn_mask_type: str = 'causal'
attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
fuse_qkv_params: bool = True
transpose_batch_sequence: bool = True
enable_sequence_parallel: bool = False
......@@ -202,6 +205,9 @@ class MultiHeadAttention(TransformerEngineBaseLayer):
bias_init=TransformerEngineBaseLayer.generate_params_init("bias", self.bias_init),
attn_mask_type=self.attn_mask_type,
attn_bias_type=self.attn_bias_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
fuse_qkv_params=self.fuse_qkv_params,
transpose_batch_sequence=self.transpose_batch_sequence,
enable_sequence_parallel=self.enable_sequence_parallel,
......@@ -255,6 +261,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
self_attn_bias_type: Optional[str] = None
enable_rotary_pos_emb: bool = False
rotary_pos_emb_windows: Tuple[int, int] = (1, 10000)
rotary_pos_emb_group_method: str = 'consecutive'
enable_relative_embedding: bool = True
relative_embedding: pax_fiddle.Config[RelativePositionBiases] = pax_fiddle.template_field(None)
drop_path: float = 0.0
......@@ -324,6 +331,7 @@ class TransformerLayer(TransformerEngineBaseLayer):
self_attn_bias_type=self.self_attn_bias_type,
enable_rotary_pos_emb=self.enable_rotary_pos_emb,
rotary_pos_emb_windows=self.rotary_pos_emb_windows,
rotary_pos_emb_group_method=self.rotary_pos_emb_group_method,
enable_relative_embedding=self.enable_relative_embedding,
relative_embedding=relative_embedding_flax_module,
drop_path=self.drop_path,
......
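
Usage note (not part of the diff): the grouping is selected via the new trailing group_method argument of rotary_pos_emb, with 'consecutive' remaining the default. A minimal sketch follows; the import path is an assumption, while the keyword names match the signature added above.

import jax.numpy as jnp
from transformer_engine.jax.flax.transformer import rotary_pos_emb   # import path assumed

q = jnp.ones((2, 16, 8, 64))                         # [batch, seq, heads, hidden]
q_rope = rotary_pos_emb(q,
                        windows=(1, 10000),          # min/max RoPE time-scales
                        transpose_batch_sequence=False,
                        group_method='alternate')    # or 'consecutive' (default)
print(q_rope.shape)                                  # (2, 16, 8, 64)
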