Unverified commit bc5d4c18, authored by Kirthi Shankar Sivamani and committed by GitHub

Check input dimensions for Sequence Parallel (#208)

Check input dimensions for SP

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent e6bca031
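For context: with sequence parallelism, activations entering the layer are already split along the sequence (first) dimension across the tensor-parallel group, so the layer receives [seq_length // tp_size, batch, hidden] rather than the full sequence. A minimal single-process sketch of that layout (all sizes made up for illustration, not taken from this commit):

import torch

seq_length, batch_size, hidden_size, tp_size = 2048, 2, 1024, 4

full_input = torch.randn(seq_length, batch_size, hidden_size)

# Under sequence parallelism each tensor-parallel rank holds one slice
# of the sequence dimension; shown here on a single process.
per_rank = torch.chunk(full_input, tp_size, dim=0)[0]
assert per_rank.shape[0] == seq_length // tp_size  # 512 tokens per rank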
@@ -265,8 +265,9 @@ class TransformerLayer(torch.nn.Module):
         if output_layer_init_method is None:
             output_layer_init_method = get_default_init_method()
-        tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
-        self.sequence_parallel = (tp_size > 1) and sequence_parallel
+        self.tp_size = tp_size if tp_group is None else get_distributed_world_size(tp_group)
+        self.sequence_parallel = (self.tp_size > 1) and sequence_parallel
+        self.seq_length = seq_length
         self.get_rng_state_tracker = get_rng_state_tracker
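Note on the hunk above: tp_size is now cached on self (and seq_length recorded) because the forward-pass check added later in this commit needs both at call time. get_distributed_world_size is a TransformerEngine helper; in plain torch.distributed terms the resolution is roughly the following (a sketch, not the library's code):

import torch.distributed as dist

def resolve_tp_size(tp_size: int, tp_group=None) -> int:
    # Mirrors the constructor line above: an explicit process group
    # overrides the tp_size argument.
    if tp_group is None:
        return tp_size
    return dist.get_world_size(group=tp_group)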
@@ -282,7 +283,7 @@ class TransformerLayer(torch.nn.Module):
         common_attention_kwargs = {
             "layer_number": layer_number,
             "tp_group": tp_group,
-            "tp_size": tp_size,
+            "tp_size": self.tp_size,
             "fuse_wgrad_accumulation": fuse_wgrad_accumulation,
             "get_rng_state_tracker": get_rng_state_tracker,
             "sequence_parallel": self.sequence_parallel,
@@ -326,7 +327,7 @@ class TransformerLayer(torch.nn.Module):
             eps=layernorm_epsilon,
             fuse_wgrad_accumulation=fuse_wgrad_accumulation,
             tp_group=tp_group,
-            tp_size=tp_size,
+            tp_size=self.tp_size,
             get_rng_state_tracker=get_rng_state_tracker,
             init_method=init_method,
             output_layer_init_method=output_layer_init_method,
@@ -361,7 +362,7 @@ class TransformerLayer(torch.nn.Module):
             set_jit_fusion_options()
             if seq_length and micro_batch_size:
                 if self.sequence_parallel:
-                    seq_length = seq_length // tp_size
+                    seq_length = seq_length // self.tp_size
                 warmup_jit_bias_dropout_add_all_dtypes(
                     hidden_size, seq_length, micro_batch_size
                 )
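The JIT warmup now divides by self.tp_size so the traced kernels match the per-rank sequence length. A quick worked example of the arithmetic (values made up):

seq_length, tp_size = 2048, 8
sequence_parallel = True

if sequence_parallel:
    seq_length = seq_length // tp_size  # 2048 // 8 == 256

# warmup_jit_bias_dropout_add_all_dtypes would then warm up kernels for
# 256-token per-rank activations instead of the full 2048-token sequence.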
@@ -436,6 +437,11 @@ class TransformerLayer(torch.nn.Module):
             hidden_states = hidden_states.contiguous()
 
+        if self.sequence_parallel and self.seq_length is not None:
+            assert (
+                hidden_states.shape[0] == self.seq_length // self.tp_size
+            ), "Sequence dimension must be split across TP group when using sequence parallel."
+
         if self.self_attn_mask_type != "causal" and attention_mask is not None:
             assert (
                 attention_mask.dtype == torch.bool
...
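The new assertion is the user-visible effect of this commit: with sequence_parallel=True, an unsplit input now fails fast at the top of forward instead of producing shape errors deeper in the layer. A hedged usage sketch (layer construction and distributed setup elided; sizes made up):

import torch

seq_length, tp_size, batch_size, hidden_size = 2048, 4, 2, 1024

good = torch.randn(seq_length // tp_size, batch_size, hidden_size)  # [512, 2, 1024]
bad = torch.randn(seq_length, batch_size, hidden_size)               # [2048, 2, 1024]

# layer(good)  # passes the new check
# layer(bad)   # AssertionError: Sequence dimension must be split across
#              # TP group when using sequence parallel.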