Commit 1e712ea8 authored by Tri Dao

Implement TensorParallel for MHA

parent 226a1b72
#include <torch/extension.h>
+#include <c10/cuda/CUDAGuard.h>

-#define CHECK_DEVICE(x) \
-    TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
-#define CHECK_SHAPE(x, ...) \
-    TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), \
-                #x " must have shape (" #__VA_ARGS__ ")")
+#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
+#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")

void apply_rotary_cuda(const torch::Tensor x1, const torch::Tensor x2,
                       const torch::Tensor cos, const torch::Tensor sin,
@@ -26,6 +24,11 @@ void apply_rotary(const torch::Tensor x1, const torch::Tensor x2,
    TORCH_CHECK(x1.sizes() == x2.sizes());
    TORCH_CHECK(cos.sizes() == sin.sizes());
    TORCH_CHECK(out1.sizes() == out2.sizes());
+    // Otherwise the kernel will be launched from cuda:0 device
+    // Cast to char to avoid compiler warning about narrowing
+    at::cuda::CUDAGuard device_guard{(char)x1.get_device()};
    apply_rotary_cuda(x1, x2, cos, sin, out1, out2, conj);
}
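For orientation, the rotation that apply_rotary checks and dispatches here can be sketched in plain PyTorch. This is a minimal reference sketch assuming the standard rotary formulation (conj flips the sign of sin, as used for the backward pass); it is not the CUDA kernel itself.

import torch

def apply_rotary_reference(x1, x2, cos, sin, conj=False):
    # Reference rotation assumed to match the kernel: each (x1, x2) pair is
    # rotated by the angle whose cosine/sine are given; conj applies the inverse.
    if conj:
        sin = -sin
    out1 = x1 * cos - x2 * sin
    out2 = x1 * sin + x2 * cos
    return out1, out2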
...
@@ -137,17 +137,19 @@ class RotaryEmbedding(torch.nn.Module):
    """
-    def __init__(self, dim: int, base=10000, scale_base=0, *_, **__):
+    def __init__(self, dim: int, base=10000, scale_base=0, device=None):
        """
        If scale_base > 0, this implements XPos (Sun et al., https://arxiv.org/abs/2212.10554).
        Reference: https://github.com/sunyt32/torchscale/blob/main/torchscale/component/xpos_relative_position.py
        """
        super().__init__()
        # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device,
+                                                dtype=torch.float32) / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.scale_base = scale_base
-        scale = (torch.arange(0, dim, 2) + 0.4 * dim) / (1.4 * dim) if scale_base > 0 else None
+        scale = ((torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
+                 / (1.4 * dim) if scale_base > 0 else None)
        self.register_buffer("scale", scale)

        self._seq_len_cached = 0
@@ -168,14 +170,14 @@ class RotaryEmbedding(torch.nn.Module):
            t = torch.arange(seqlen, device=x.device, dtype=self.inv_freq.dtype)
            # Don't do einsum, it converts fp32 to fp16
            # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
-            freqs = torch.outer(t, self.inv_freq)
+            freqs = torch.outer(t, self.inv_freq.to(device=t.device))
            if self.scale is None:
                self._cos_cached = torch.cos(freqs).to(x.dtype)
                self._sin_cached = torch.sin(freqs).to(x.dtype)
            else:
                power = ((torch.arange(seqlen, dtype=self.scale.dtype, device=self.scale.device)
                          - seqlen // 2) / self.scale_base)
-                scale = self.scale ** rearrange(power, 's -> s 1')
+                scale = self.scale.to(device=power.device) ** rearrange(power, 's -> s 1')
                # We want the multiplication by scale to happen in fp32
                self._cos_cached = (torch.cos(freqs) * scale).to(x.dtype)
                self._sin_cached = (torch.sin(freqs) * scale).to(x.dtype)
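To make the XPos scaling above concrete, here is a small worked sketch with hypothetical sizes (dim=8, seqlen=6, scale_base=512; none of these come from the diff). The cached cos/sin are multiplied by a per-position, per-dimension scale; in XPos the expectation is that queries use this scale and keys its reciprocal, so their dot product decays with relative distance.

import torch

dim, seqlen, scale_base = 8, 6, 512  # hypothetical sizes, for illustration only

base_scale = (torch.arange(0, dim, 2, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)  # (dim/2,)
power = (torch.arange(seqlen, dtype=torch.float32) - seqlen // 2) / scale_base          # (seqlen,)
scale = base_scale ** power[:, None]                                                    # (seqlen, dim/2)

# XPos pairs `scale` on the query side with `1 / scale` on the key side, so the
# attention logit between positions i and j carries a factor scale[i] / scale[j].
print(scale.shape)  # torch.Size([6, 4])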
...
@@ -21,9 +21,9 @@ except ImportError:
    flash_attn_qkvpacked_func, flash_attn_kvpacked_func = None, None

try:
-    from flash_attn.ops.fused_dense import FusedDense
+    from flash_attn.ops.fused_dense import FusedDense, ColumnParallelLinear, RowParallelLinear
except ImportError:
-    FusedDense = None
+    FusedDense, ColumnParallelLinear, RowParallelLinear = None, None, None

try:
    from flash_attn.layers.rotary import RotaryEmbedding
@@ -42,7 +42,7 @@ class FlashSelfAttention(nn.Module):
            (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
-                 triton=False, device=None, dtype=None):
+                 triton=False):
        super().__init__()
        if attention_dropout != 0.0 or not triton:
            assert flash_attn_unpadded_qkvpacked_func is not None, 'FlashAttention is not installed'
@@ -109,7 +109,7 @@ class FlashCrossAttention(nn.Module):
            (default: 0.0)
    """
    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
-                 triton=False, device=None, dtype=None):
+                 triton=False):
        super().__init__()
        if attention_dropout != 0.0 or not triton:
            assert flash_attn_unpadded_kvpacked_func is not None, 'FlashAttention is not installed'
@@ -181,8 +181,7 @@ class SelfAttention(nn.Module):
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """
-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
-                 device=None, dtype=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
@@ -228,8 +227,7 @@ class CrossAttention(nn.Module):
        attention_dropout: The dropout rate to apply to the attention
            (default: 0.0)
    """
-    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0,
-                 device=None, dtype=None):
+    def __init__(self, causal=False, softmax_scale=None, attention_dropout=0.0):
        super().__init__()
        self.causal = causal
        self.softmax_scale = softmax_scale
@@ -309,7 +307,8 @@ class MHA(nn.Module):
        if self.rotary_emb_dim > 0:
            assert not cross_attn, 'MHA with rotary embedding does not support cross-attention yet'
            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
-            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base)
+            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
+                                              device=device)

        if fused_bias_fc and FusedDense is None:
            raise ImportError('fused_dense is not installed')
@@ -338,7 +337,7 @@ class MHA(nn.Module):
                                       groups=2 * embed_dim)
        inner_attn_cls = FlashCrossAttention if use_flash_attn else CrossAttention
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
-                                         attention_dropout=dropout, **factory_kwargs)
+                                         attention_dropout=dropout)
        # output projection always have the bias (for now)
        self.out_proj = linear_cls(embed_dim, embed_dim, **factory_kwargs)
@@ -378,7 +377,7 @@ class MHA(nn.Module):
            if self.dwconv:
                qkv = rearrange(self.dwconv_qkv(rearrange(qkv, 'b s d -> b d s'))[..., :-2],
                                'b d s -> b s d').contiguous()
-            qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, h=self.num_heads)
+            qkv = rearrange(qkv, '... (three h d) -> ... three h d', three=3, d=self.head_dim)
            if self.rotary_emb_dim > 0:
                qkv = self.rotary_emb(qkv)
            if not self.checkpointing:
@@ -395,8 +394,8 @@ class MHA(nn.Module):
                else:
                    kv, x = self.Wkv(x)
                q = self.Wq(x)
-            q = rearrange(q, '... (h d) -> ... h d', h=self.num_heads)
-            kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, h=self.num_heads)
+            q = rearrange(q, '... (h d) -> ... h d', d=self.head_dim)
+            kv = rearrange(kv, '... (two h d) -> ... two h d', two=2, d=self.head_dim)
            if self.dwconv:
                q = rearrange(self.dwconv_q(rearrange(q, 'b s d -> b d s'))[..., :-2],
                              'b d s -> b s d').contiguous()
@@ -408,3 +407,66 @@ class MHA(nn.Module):
            context = torch.utils.checkpoint.checkpoint(self.inner_attn, q, kv, **kwargs)
        out = self.out_proj(rearrange(context, '... h d -> ... (h d)'))
        return out if not self.return_residual else (out, x)

class ParallelMHA(nn.Module):
    """Multi-head self-attention, tensor-parallel across the process group."""

    def __init__(self, embed_dim, num_heads, process_group, bias=True, dropout=0.0,
                 softmax_scale=None, causal=False, rotary_emb_dim=0, rotary_emb_scale_base=0,
                 use_flash_attn=False, checkpointing=False, device=None, dtype=None) -> None:
        factory_kwargs = {'device': device, 'dtype': dtype}
        super().__init__()
        self.process_group = process_group
        self.embed_dim = embed_dim
        self.causal = causal
        self.rotary_emb_dim = rotary_emb_dim
        self.use_flash_attn = use_flash_attn
        self.checkpointing = checkpointing
        self.num_heads = num_heads
        assert self.embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"
        self.head_dim = self.embed_dim // num_heads

        if self.rotary_emb_dim > 0:
            assert RotaryEmbedding is not None, 'rotary_emb is not installed'
            self.rotary_emb = RotaryEmbedding(self.rotary_emb_dim, scale_base=rotary_emb_scale_base,
                                              device=device)

        if ColumnParallelLinear is None or RowParallelLinear is None:
            raise ImportError('fused_dense is not installed')
        self.Wqkv = ColumnParallelLinear(embed_dim, 3 * embed_dim, process_group, bias=bias,
                                         **factory_kwargs)
        inner_attn_cls = FlashSelfAttention if use_flash_attn else SelfAttention
        self.inner_attn = inner_attn_cls(causal=causal, softmax_scale=softmax_scale,
                                         attention_dropout=dropout)
        # output projection always has the bias (for now)
        self.out_proj = RowParallelLinear(embed_dim, embed_dim, process_group, **factory_kwargs)

    def forward(self, x, seqlen=None, **kwargs):
        """
        Arguments:
            x: (batch, seqlen, hidden_dim) (where hidden_dim = num heads * head dim) if seqlen=None.
                If seqlen is not None, x is (batch * seqlen, hidden_dim). This is so that when we
                split x during sequence parallel, we split the batch * seqlen dimension
                (in case batch is small).
        """
        qkv = self.Wqkv(x)
        if seqlen is None:
            qkv = rearrange(qkv, 'b s (three h d) -> b s three h d', three=3, d=self.head_dim)
        else:
            qkv = rearrange(qkv, '(b s) (three h d) -> b s three h d', s=seqlen, three=3,
                            d=self.head_dim)
        if self.rotary_emb_dim > 0:
            qkv = self.rotary_emb(qkv)
        if not self.checkpointing:
            context = self.inner_attn(qkv, **kwargs)
        else:
            context = torch.utils.checkpoint.checkpoint(self.inner_attn, qkv, **kwargs)
        if seqlen is None:
            context = rearrange(context, 'b s h d -> b s (h d)')
        else:
            context = rearrange(context, 'b s h d -> (b s) (h d)')
        out = self.out_proj(context)
        return out
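The test below exercises this class end to end; as a quicker orientation, here is a minimal usage sketch with hypothetical sizes. It assumes torch.distributed and apex's tensor-model-parallel state are already initialized, exactly as the test does.

import torch
from apex.transformer import parallel_state
from flash_attn.modules.mha import ParallelMHA

# Assumes torch.distributed.init_process_group(...) and
# parallel_state.initialize_model_parallel(...) have already run (see the test below).
world_size = parallel_state.get_tensor_model_parallel_world_size()
device = f'cuda:{torch.distributed.get_rank()}'
mha = ParallelMHA(
    embed_dim=1024, num_heads=16,  # hypothetical sizes
    process_group=parallel_state.get_tensor_model_parallel_group(),
    use_flash_attn=True, device=device, dtype=torch.float16,
)
# Sequence-parallel input: each rank holds its (batch * seqlen / world_size, embed_dim) shard.
batch_size, seqlen = 8, 1024
x_local = torch.randn(batch_size * seqlen // world_size, 1024,
                      device=device, dtype=torch.float16)
out_local = mha(x_local, seqlen=seqlen)  # same shard shape as x_local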
# Run test with:
# torchrun --no_python --nproc_per_node=8 pytest -q -s tests/modules/test_mha_parallel.py
import math
import torch
import torch.nn.functional as F
import pytest
from einops import rearrange
from apex.transformer import parallel_state
from apex.transformer import tensor_parallel
from flash_attn.modules.mha import MHA, ParallelMHA
is_sm8x = torch.cuda.get_device_capability('cuda')[0] >= 8
@pytest.mark.parametrize('dtype', [torch.float16] + ([torch.bfloat16] if is_sm8x else []))
# @pytest.mark.parametrize('dtype', [torch.float16])
@pytest.mark.parametrize('world_size', [1, 2, 4, 8])
# @pytest.mark.parametrize('world_size', [2])
@pytest.mark.parametrize('head_dim', [64, 128])
# @pytest.mark.parametrize('head_dim', [64])
@pytest.mark.parametrize('embed_dim', [1024, 4096])
# @pytest.mark.parametrize('embed_dim', [1024])
def test_mha_parallel(embed_dim, head_dim, world_size, dtype):
    assert embed_dim % head_dim == 0
    num_heads = embed_dim // head_dim
    assert num_heads % world_size == 0
    rtol, atol = (3e-3, 1e-2) if dtype == torch.bfloat16 else (3e-3, 1e-3)
    if not torch.distributed.is_initialized():
        torch.distributed.init_process_group(backend='nccl', init_method='env://')
    device = f'cuda:{torch.distributed.get_rank()}'
    assert world_size <= torch.distributed.get_world_size()
    parallel_state.initialize_model_parallel(tensor_model_parallel_size_=world_size)
    rank = parallel_state.get_tensor_model_parallel_rank()
    # set seed
    torch.random.manual_seed(0)
    batch_size = 8
    seqlen = 1024
    assert (batch_size * seqlen) % world_size == 0
    x_pt = torch.randn(batch_size * seqlen, embed_dim, device=device, dtype=dtype,
                       requires_grad=True)
    # We need to generate g here so that all processes get the same gradient,
    # as rank 0 will have an extra bias that changes the RNG.
    # If we don't divide by batch_size, the gradient gets a bit too large.
    g = torch.randn_like(x_pt) / 32
    x = tensor_parallel.scatter_to_sequence_parallel_region(x_pt).detach().clone().requires_grad_()
    model_pt = MHA(embed_dim, num_heads, rotary_emb_dim=int(head_dim // 2),
                   use_flash_attn=True, device=device, dtype=dtype)
    partition_dim = embed_dim // world_size
    model = ParallelMHA(embed_dim, num_heads, parallel_state.get_tensor_model_parallel_group(),
                        rotary_emb_dim=int(head_dim // 2), use_flash_attn=True,
                        device=device, dtype=dtype)
    with torch.no_grad():
        model.Wqkv.weight.copy_(
            rearrange(rearrange(model_pt.Wqkv.weight, '(three o) i -> three o i',
                                three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
                      'three o i -> (three o) i')
        )
        model.Wqkv.bias.copy_(
            rearrange(rearrange(model_pt.Wqkv.bias, '(three o) -> three o',
                                three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
                      'three o -> (three o)')
        )
        model.out_proj.weight.copy_(
            model_pt.out_proj.weight[:, rank * partition_dim:(rank + 1) * partition_dim]
        )
        if rank == 0:
            model.out_proj.bias.copy_(model_pt.out_proj.bias)
    out = model(x, seqlen=seqlen)
    out_pt = rearrange(model_pt(rearrange(x_pt, '(b s) d -> b s d', s=seqlen)), 'b s d -> (b s) d')
    partition_batch_dim = batch_size * seqlen // world_size
    assert torch.allclose(
        out, out_pt[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
        rtol=rtol, atol=atol
    )

    out_pt.backward(g)
    out.backward(g[rank * partition_batch_dim:(rank + 1) * partition_batch_dim])
    parallel_state.destroy_model_parallel()

    assert torch.allclose(
        x.grad, x_pt.grad[rank * partition_batch_dim:(rank + 1) * partition_batch_dim],
        rtol=rtol, atol=atol
    )
    # The error for d_weight and d_bias is quite a bit higher
    assert torch.allclose(
        model.Wqkv.weight.grad,
        rearrange(rearrange(model_pt.Wqkv.weight.grad, '(three o) i -> three o i',
                            three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
                  'three o i -> (three o) i'),
        rtol=rtol, atol=atol * 10
    )
    assert torch.allclose(
        model.Wqkv.bias.grad,
        rearrange(rearrange(model_pt.Wqkv.bias.grad, '(three o) -> three o',
                            three=3)[:, rank * partition_dim:(rank + 1) * partition_dim],
                  'three o -> (three o)'),
        rtol=rtol, atol=atol * 5
    )
    assert torch.allclose(
        model.out_proj.weight.grad,
        model_pt.out_proj.weight.grad[:, rank * partition_dim:(rank + 1) * partition_dim],
        rtol=rtol, atol=atol * 10
    )
    if rank == 0:
        assert torch.allclose(model.out_proj.bias.grad, model_pt.out_proj.bias.grad,
                              rtol=rtol, atol=atol * 5)