"examples/vscode:/vscode.git/clone" did not exist on "181280babac02b0f8b9c61cdb3f89e8603e5f954"
Unverified commit bb4cded1, authored by Xuechen Li, committed by GitHub

support when num_heads is not divisible by world_size; resolves #459 (#461)

* unequal rank.

* trim.

* enable passing in number of heads for each rank.

* simplify.

* simplify.

* cleanup.

* fix col parallel.

* fix bug with row parallel.

* fix out proj.

* refac.

* fix sharding logic.

* refac sharding.

* refac.

* support multiple_of.

* make fn reusable.

* fix bug in dimensions.

* scaffold.

* test uneven heads.

* fix test by adding barrier.

* refac.

* reuse code.

* clean up.
parent ada4710d
@@ -27,7 +27,7 @@ from flash_attn.modules.mlp import (
     ParallelMLP,
 )
 from flash_attn.ops.activations import sqrelu_fwd
-from flash_attn.utils.distributed import all_gather_raw, sync_shared_params
+from flash_attn.utils.distributed import all_gather_raw, sync_shared_params, get_dim_for_local_rank
 from flash_attn.utils.generation import GenerationMixin
 from flash_attn.utils.pretrained import state_dict_from_pretrained
 from transformers import GPT2Config
@@ -62,7 +62,6 @@ try:
 except ImportError:
     FusedDenseSqreluDense = None
 
-
 logger = logging.getLogger(__name__)
@@ -681,41 +680,58 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
     inner_dim = config.n_inner if config.n_inner is not None else 4 * config.hidden_size
     assert inner_dim % world_size == 0
+    n_head = config.n_head
+    n_head_kv = getattr(config, "n_head_kv", n_head)
+    embed_dim = config.hidden_size
+    head_dim = embed_dim // n_head
 
     def shard_first_dim(state_dict, key):
         if key in state_dict:
             x = state_dict[key]
             dim = x.shape[0] // world_size
-            state_dict[key] = x[rank * dim : (rank + 1) * dim]
+            state_dict[key] = x[rank * dim: (rank + 1) * dim]
 
-    def shard_last_dim(state_dict, key):
+    def shard_last_dim(state_dict, key, multiple_of=1):
         if key in state_dict:
             x = state_dict[key]
-            dim = x.shape[-1] // world_size
-            state_dict[key] = x[..., rank * dim : (rank + 1) * dim]
+            dim_each_rank = [
+                get_dim_for_local_rank(x.size(-1), world_size, local_rank, multiple_of)
+                for local_rank in range(world_size)
+            ]
+            beg, end = tuple(sum(dim_each_rank[:pos]) for pos in (rank, rank + 1))
+            state_dict[key] = x[..., beg:end]
 
     def shard_gatedmlp_fc1_dim(state_dict, key):
         if key in state_dict:
             x = state_dict[key]
             dim = x.shape[0] // world_size // 2
             state_dict[key] = rearrange(
-                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim : (rank + 1) * dim],
+                rearrange(x, "(two o) ... -> two o ...", two=2)[:, rank * dim: (rank + 1) * dim],
                 "two o ... -> (two o) ...",
             )
 
     def shard_qkv_headdim(state_dict, key):
         if key in state_dict:
-            n_head = config.n_head
-            n_head_kv = getattr(config, "n_head_kv", n_head)
-            assert n_head % world_size == 0 and n_head_kv % world_size == 0
+            n_head_each_rank = [
+                get_dim_for_local_rank(n_head, world_size, local_rank) for local_rank in range(world_size)
+            ]
+            n_head_kv_each_rank = [
+                get_dim_for_local_rank(n_head_kv, world_size, local_rank) for local_rank in range(world_size)
+            ]
+            beg_n_head = sum(n_head_each_rank[:rank])
+            end_n_head = sum(n_head_each_rank[: rank + 1])
+            beg_n_head_kv = sum(n_head_kv_each_rank[:rank])
+            end_n_head_kv = sum(n_head_kv_each_rank[: rank + 1])
             if n_head_kv == n_head:
                 x = rearrange(state_dict[key], "(three d) ... -> three d ...", three=3)
-                dim = x.shape[1] // world_size
                 state_dict[key] = rearrange(
-                    x[:, rank * dim : (rank + 1) * dim], "three d ... -> (three d) ..."
+                    x[:, beg_n_head * head_dim : end_n_head * head_dim], "three d ... -> (three d) ..."
                 )
             else:
-                n_head_per_rank = n_head // world_size
-                n_head_kv_per_rank = n_head_kv // world_size
                 x = rearrange(
                     state_dict[key],
                     "(nheadqkv headdim) ... -> nheadqkv headdim ...",
@@ -724,19 +740,9 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
             state_dict[key] = rearrange(
                 torch.cat(
                     [
-                        x[rank * n_head_per_rank : (rank + 1) * n_head_per_rank],
-                        x[
-                            n_head
-                            + rank * n_head_kv_per_rank : n_head
-                            + (rank + 1) * n_head_kv_per_rank
-                        ],
-                        x[
-                            n_head
-                            + n_head_kv
-                            + rank * n_head_kv_per_rank : n_head
-                            + n_head_kv
-                            + (rank + 1) * n_head_kv_per_rank
-                        ],
+                        x[beg_n_head:end_n_head],
+                        x[n_head + beg_n_head_kv: n_head + end_n_head_kv],
+                        x[n_head + n_head_kv + beg_n_head_kv: n_head + n_head_kv + end_n_head_kv],
                     ],
                     dim=0,
                 ),
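For a concrete sense of the slicing above, here is a small standalone sketch (illustration only, not part of the diff) that reproduces the arithmetic of the get_dim_for_local_rank helper this commit adds to flash_attn.utils.distributed, and prints the per-rank head ranges when 3 heads are split across 2 ranks:

# Standalone sketch: same arithmetic as the helper added to flash_attn.utils.distributed in this commit.
def get_dim_for_local_rank(dim, world_size, local_rank, multiple_of=1):
    multiple = dim // multiple_of
    div, mod = multiple // world_size, multiple % world_size
    return (div + int(local_rank < mod)) * multiple_of

n_head = n_head_kv = 3
world_size = 2
n_head_each_rank = [get_dim_for_local_rank(n_head, world_size, r) for r in range(world_size)]
print(n_head_each_rank)  # [2, 1]: rank 0 owns heads 0-1, rank 1 owns head 2
for rank in range(world_size):
    beg_n_head = sum(n_head_each_rank[:rank])
    end_n_head = sum(n_head_each_rank[: rank + 1])
    print(rank, beg_n_head, end_n_head)  # rank 0 -> heads [0, 2), rank 1 -> heads [2, 3)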
@@ -751,7 +757,9 @@ def shard_state_dict_tp(state_dict, config, world_size, rank):
     for i in range(config.num_hidden_layers):
         shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.weight")
         shard_qkv_headdim(state_dict, f"transformer.layers.{i}.mixer.Wqkv.bias")
-        shard_last_dim(state_dict, f"transformer.layers.{i}.mixer.out_proj.weight")
+        shard_last_dim(
+            state_dict, f"transformer.layers.{i}.mixer.out_proj.weight", multiple_of=head_dim
+        )
         if rank != 0:
             state_dict.pop(f"transformer.layers.{i}.mixer.out_proj.bias", None)
         if config.activation_function in ["glu", "swiglu", "geglu"]:
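The multiple_of=head_dim argument keeps the out_proj column split aligned with whole heads. A tiny numeric sketch (hypothetical sizes, not part of the diff): with 3 heads of head_dim 4 split over 2 ranks, get_dim_for_local_rank(12, 2, rank, multiple_of=4) yields [8, 4] columns per rank, so:

import torch

# Hypothetical tiny out_proj: embed_dim = 12, 3 heads of head_dim = 4, world_size = 2.
weight = torch.randn(12, 12)     # out_proj.weight has shape (embed_dim, n_head * head_dim)

rank0 = weight[..., 0:8]         # columns of heads 0 and 1, kept by rank 0
rank1 = weight[..., 8:12]        # columns of head 2, kept by rank 1
print(rank0.shape, rank1.shape)  # torch.Size([12, 8]) torch.Size([12, 4])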
@@ -816,7 +824,7 @@ def combine_state_dicts_tp(state_dicts, config):
                     torch.cat([x[:n_head_per_rank] for x in xs], dim=0),
                     torch.cat(
                         [
-                            x[n_head_per_rank : n_head_per_rank + n_head_kv_per_rank]
+                            x[n_head_per_rank: n_head_per_rank + n_head_kv_per_rank]
                             for x in xs
                         ],
                         dim=0,
@@ -922,6 +930,7 @@ def remap_state_dict_megatron(state_dict, config):
         return key
 
     state_dict = OrderedDict((key_mapping_transformer(k), v) for k, v in state_dict.items())
+
     # Word embedding and position embedding
     def key_mapping_pos_emb(key):
         return re.sub(r"^wpe.", "transformer.embeddings.position_embeddings.", key)
...
@@ -5,9 +5,10 @@ from functools import partial
 
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from einops import rearrange, repeat
 
+from flash_attn.utils.distributed import get_dim_for_local_rank
+
 try:
     from flash_attn import (
         flash_attn_kvpacked_func,
@@ -720,22 +721,21 @@ class ParallelMHA(nn.Module):
         self.use_flash_attn = use_flash_attn
         self.checkpointing = checkpointing
         self.process_group = process_group
-        self.world_size = process_group.size() if process_group is not None else 1
+        self.world_size = process_group.size()
+        self.local_rank = torch.distributed.get_rank(process_group)
 
         self.num_heads = num_heads
+        assert self.embed_dim % self.num_heads == 0, "embed_dim must be divisible by num_heads"
+
         self.num_heads_kv = num_heads_kv if num_heads_kv is not None else num_heads
-        self.num_heads_per_rank = num_heads // self.world_size
-        self.num_heads_kv_per_rank = self.num_heads_kv // self.world_size
         assert (
             self.num_heads % self.num_heads_kv == 0
         ), "num_heads must be divisible by num_heads_kv"
-        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
-        assert (
-            self.num_heads_kv % self.world_size == 0
-        ), "num_heads_kv must be divisible by world_size"
+
+        self.num_heads_per_rank = get_dim_for_local_rank(self.num_heads, self.world_size, self.local_rank)
+        self.num_heads_kv_per_rank = get_dim_for_local_rank(self.num_heads_kv, self.world_size, self.local_rank)
         self.head_dim = self.embed_dim // num_heads
         qkv_dim = self.head_dim * (self.num_heads + 2 * self.num_heads_kv)
-        kv_dim = 2 * self.head_dim * self.num_heads_kv
 
         if self.rotary_emb_dim > 0:
             assert RotaryEmbedding is not None, "rotary_emb is not installed"
@@ -755,6 +755,7 @@ class ParallelMHA(nn.Module):
             process_group,
             bias=qkv_proj_bias,
             sequence_parallel=sequence_parallel,
+            multiple_of=self.head_dim * 3,
             **factory_kwargs,
         )
         inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
@@ -771,6 +772,7 @@ class ParallelMHA(nn.Module):
             process_group,
             bias=out_proj_bias,
             sequence_parallel=sequence_parallel,
+            multiple_of=self.head_dim,
             **factory_kwargs,
         )
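Putting the two multiple_of values together, a quick arithmetic sketch (hypothetical sizes, not part of the diff) of the local projection widths each rank ends up with when num_heads = num_heads_kv = 3 and head_dim = 4 are split across 2 ranks:

# Hypothetical sizes: embed_dim = 12, num_heads = num_heads_kv = 3, head_dim = 4, world_size = 2.
head_dim = 4
heads_per_rank = [2, 1]  # get_dim_for_local_rank(3, 2, rank) for rank 0 and rank 1
for rank, h in enumerate(heads_per_rank):
    wqkv_out = head_dim * (h + 2 * h)   # local Wqkv output width; whole chunks of multiple_of = head_dim * 3
    out_proj_in = head_dim * h          # local out_proj input width; whole chunks of multiple_of = head_dim
    print(rank, wqkv_out, out_proj_in)  # rank 0 -> 24, 8; rank 1 -> 12, 4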
...
@@ -226,7 +226,7 @@ class RowParallelLinear(nn.Linear):
         local_multiple = div + int(torch.distributed.get_rank(process_group) < mod)
         # Only rank 0 will have bias
         super().__init__(
-            in_features // world_size,
+            local_multiple * multiple_of,
             out_features,
             bias=bias and rank == 0,
             device=device,
...
@@ -125,3 +125,15 @@ def allreduce_sequence_parallel_grad(model: torch.nn.Module, process_group: ProcessGroup):
             torch.distributed.all_reduce(coalesced, group=process_group)
             for buf, synced in zip(grads, torch._utils._unflatten_dense_tensors(coalesced, grads)):
                 buf.copy_(synced)
+
+
+def get_dim_for_local_rank(dim: int, world_size: int, local_rank: int, multiple_of: int = 1) -> int:
+    """Get the dim for the local rank derived from splitting dim on world_size processes.
+
+    The split may not be even across the world_size processes.
+    """
+    multiple = dim // multiple_of
+    div = multiple // world_size
+    mod = multiple % world_size
+    local_multiple = div + int(local_rank < mod)
+    return local_multiple * multiple_of
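A short usage sketch (not part of the diff) of the helper just added: it hands out dim in multiple_of-sized chunks, giving the lower ranks one extra chunk when the split is uneven.

# Usage sketch; assumes a flash-attn build that includes this change.
from flash_attn.utils.distributed import get_dim_for_local_rank

# 7 attention heads over 4 ranks: the first 3 ranks get 2 heads, the last gets 1.
print([get_dim_for_local_rank(7, 4, r) for r in range(4)])        # [2, 2, 2, 1]
# A 384-wide projection over 4 ranks in whole-head chunks of head_dim = 64 (6 heads total).
print([get_dim_for_local_rank(384, 4, r, 64) for r in range(4)])  # [128, 128, 64, 64]
# The per-rank pieces always sum back to the full dimension.
assert sum(get_dim_for_local_rank(7, 4, r) for r in range(4)) == 7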
@@ -16,7 +16,7 @@ import pytest
 from einops import rearrange
-from transformers import LlamaTokenizer
+from transformers import LlamaTokenizer, LlamaConfig
 from transformers.models.llama.modeling_llama import LlamaForCausalLM
 
 from flash_attn.models.gpt import GPTLMHeadModel, combine_state_dicts_tp, shard_state_dict_tp
@@ -255,7 +255,6 @@ def test_llama_generation(model_name, checkpoint_format):
     logits_ref = model_ref(out_hf.sequences).logits[:, (seqlen - 1):-1].to(device=device)
     del model_ref
-
     pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
         checkpoint_path, model_name, config, checkpoint_format
     )
@@ -297,8 +296,8 @@ def test_llama_generation(model_name, checkpoint_format):
     hf_error = (logits_hf - logits_ref).abs().max().item()
     print(f'HF fp16 logits max diff: {hf_error}')
-    print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
-    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
     assert (logits_parallel - logits_ref).abs().max().item() < 2 * hf_error
     assert (logits - logits_ref).abs().max().item() < 2 * hf_error
@@ -410,7 +409,101 @@ def test_llama_parallel_generation(model_name, world_size, checkpoint_format):
     hf_error = (logits_hf - logits_ref).abs().max().item()
     print(f'HF fp16 logits max diff: {hf_error}')
-    print(f'Logits max diff: {(logits - logits_ref).abs().max().item() }')
+    print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
     assert (logits - logits_ref).abs().max().item() < 2 * hf_error
-    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item() }')
+    print(f'Logits CG max diff: {(logits_cg - logits_ref).abs().max().item()}')
     assert torch.equal(logits_cg, logits)
+
+
+@torch.no_grad()
+@pytest.mark.parametrize('world_size', [2])
+def test_llama_parallel_uneven_num_heads(world_size):
+    from apex.transformer import parallel_state
+
+    checkpoint_path = Path(os.environ.get('CHECKPOINT_DIR', current_dir.parent.parent / 'checkpoints')) / 'llama'
+    num_attention_heads = world_size + 1
+    model_name = f'teeny-{num_attention_heads}-heads'
+
+    if not torch.distributed.is_initialized():
+        torch.distributed.init_process_group(backend='nccl', init_method='env://')
+    device = f'cuda:{torch.distributed.get_rank()}'
+    assert world_size <= torch.distributed.get_world_size()
+    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
+    rank = parallel_state.get_tensor_model_parallel_rank()
+    process_group = parallel_state.get_tensor_model_parallel_group()
+
+    dtype = torch.float16
+    llama_config = LlamaConfig(
+        hidden_size=256 * num_attention_heads,  # ParallelGatedMlp hidden_features must be divisible by 256
+        intermediate_size=256 * num_attention_heads * 4,
+        num_hidden_layers=4,
+        num_attention_heads=num_attention_heads,
+        initializer_range=0.5,  # Set crazy init range so we don't have near zero weights implying a vacuous test.
+    )
+    config = llama_config_to_gpt2_config(llama_config)
+    config.use_flash_attn = True
+    config.fused_bias_fc = True
+    config.fused_mlp = False  # We don't have fused GatedMLP yet
+    config.fused_dropout_add_ln = True
+    config.residual_in_fp32 = True
+
+    torch.manual_seed(0)
+    batch_size = 2
+    max_seqlen = 256
+    seqlens = torch.randint(max_seqlen // 2, max_seqlen + 1, (batch_size,), device=device)
+    input_ids = torch.randint(0, config.vocab_size, (batch_size, max_seqlen), dtype=torch.long,
+                              device=device)
+
+    # Create a shared test model.
+    if rank == 0:
+        LlamaForCausalLM(config=llama_config).save_pretrained(checkpoint_path / f"{model_name}-hf")
+    torch.distributed.barrier()
+
+    # Run the standard forward pass test.
+    pretrained_state_dict = _pretrained_state_dict_from_checkpoint(
+        checkpoint_path, model_name, config, checkpoint_format="hf"
+    )
+    model = GPTLMHeadModel(config, process_group=process_group, device=device, dtype=dtype)
+    model.load_state_dict(shard_state_dict_tp(pretrained_state_dict, config, world_size, rank))
+    model.eval()
+
+    # TODO: Avoid duplicate code. Modularize the comparison of two forward pass diffs.
+    out = model.transformer(input_ids)
+    out, _ = all_gather_raw(out, process_group=process_group)
+    out = rearrange(out, "(b s) d -> b s d", b=batch_size)
+    logits = model(input_ids).logits
+    logits = rearrange(logits, "(b s) d -> b s d", b=batch_size)
+    logits, _ = all_gather_raw(logits, process_group)
+    logits = rearrange(logits, '(n b) ... d -> b ... (n d)', b=batch_size)
+
+    if rank == 0:
+        model_ref = LlamaForCausalLM.from_pretrained(
+            Path(checkpoint_path) / f'{model_name}-hf', device_map="auto"
+        )
+        model_ref.eval()
+        out_ref = model_ref.model(input_ids).last_hidden_state.to(device=device)
+        logits_ref = model_ref(input_ids).logits.to(device=device)
+        del model_ref
+
+        model_hf = LlamaForCausalLM.from_pretrained(
+            Path(checkpoint_path) / f'{model_name}-hf', torch_dtype=dtype, device_map="auto"
+        )
+        model_hf.eval()
+        out_hf = model_hf.model(input_ids).last_hidden_state.to(device=device)
+        logits_hf = model_hf(input_ids).logits.to(device=device)
+        del model_hf
+
+        print(f'Output max diff: {(out - out_ref).abs().max().item()}')
+        print(f'Output mean diff: {(out - out_ref).abs().mean().item()}')
+        print(f'HF fp16 max diff: {(out_hf - out_ref).abs().max().item()}')
+        print(f'HF fp16 mean diff: {(out_hf - out_ref).abs().mean().item()}')
+        assert (out - out_ref).abs().max().item() < 2 * (out_hf - out_ref).abs().max().item()
+
+        print(f'Logits max diff: {(logits - logits_ref).abs().max().item()}')
+        print(f'Logits mean diff: {(logits - logits_ref).abs().mean().item()}')
+        print(f'HF fp16 max diff: {(logits_hf - logits_ref).abs().max().item()}')
+        print(f'HF fp16 mean diff: {(logits_hf - logits_ref).abs().mean().item()}')
+        assert (logits - logits_ref).abs().max().item() < 2 * (logits_hf - logits_ref).abs().max().item()
+
+        import shutil
+        shutil.rmtree(checkpoint_path / f'{model_name}-hf')