Commit 8adb1b1d authored by Przemyslaw Tredak, committed by GitHub

Faster split of QKV for FlashAttention (#166)



* Faster split of QKV for FlashAttention
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* CI fixes
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Fix
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Message with assert
Co-authored-by: Przemyslaw Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* Review comments
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* review
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* fix misalignment error
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* make clarifying comment and check strides
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

---------
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
parent 4f3d6341
@@ -1767,6 +1767,175 @@ bool userbuf_comm_available() { // TODO(ksivamani) check on python side
void placeholder() {} // TODO(ksivamani) clean this up
namespace flash_attention {
constexpr int warp_size = 32;
constexpr int type_size = 2; // FP16 or BF16
constexpr int nvec = sizeof(uint64_t) / type_size;
constexpr int load_size = warp_size * nvec;
constexpr int block_size = 512;
template <typename T>
__launch_bounds__(block_size)
__global__ void prepare_kernel_fwd(const T *qkvi,
T *qkv,
const size_t B,
const size_t S,
const size_t Z,
const size_t W) {
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = blockIdx.y * W + warpid * 3 * W * Z + id_in_warp * nvec;
const T *my_input = qkvi + offset_input;
const size_t s = warpid / B;
if (s >= S) return;
const size_t b = warpid % B;
const size_t offset_output = blockIdx.y * B * S * Z * W +
(s + b * S) * W * Z +
id_in_warp * nvec;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t*>(my_output + i * load_size);
*out = *reinterpret_cast<const uint64_t*>(my_input + i * load_size * 3);
}
}
template <typename T>
__launch_bounds__(block_size)
__global__ void prepare_kernel_bwd(const T *q, const T *k, const T *v,
T *qkv, const size_t B, const size_t S,
const size_t Z, const size_t W) {
const T *input = blockIdx.y == 0 ? q : (blockIdx.y == 1 ? k : v);
const int warpid = (blockDim.x * blockIdx.x + threadIdx.x) / warp_size;
const int id_in_warp = threadIdx.x % warp_size;
const size_t offset_input = warpid * W * Z + id_in_warp * nvec;
const T *my_input = input + offset_input;
const size_t b = warpid / S;
if (b >= B) return;
const size_t s = warpid % S;
const size_t offset_output = (b + s * B) * 3 * W * Z +
id_in_warp * nvec + blockIdx.y * W;
T *my_output = qkv + offset_output;
for (int i = 0; i < Z; ++i) {
uint64_t *out = reinterpret_cast<uint64_t*>(my_output + i * load_size * 3);
*out = *reinterpret_cast<const uint64_t*>(my_input + i * load_size);
}
}
} // namespace flash_attention
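As an aside on the constants at the top of the flash_attention namespace: with 2-byte FP16/BF16 elements, each thread copies one 64-bit word (four elements) per load, so a 32-thread warp moves 128 elements, exactly one head row, per iteration. A quick Python sketch of that arithmetic (illustrative only, not part of the diff):

type_size = 2                 # bytes per FP16/BF16 element
nvec = 8 // type_size         # 4 elements per 64-bit (uint64_t) load
warp_size = 32
load_size = warp_size * nvec  # 128 elements: one full head row per warp pass
assert load_size == 128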
at::Tensor fa_prepare_fwd(at::Tensor qkvi) {
NVTE_CHECK(qkvi.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(qkvi.scalar_type() == at::ScalarType::Half ||
qkvi.scalar_type() == at::ScalarType::BFloat16);
NVTE_CHECK(qkvi.size(3) % flash_attention::load_size == 0);
NVTE_CHECK(qkvi.size(3) == flash_attention::load_size);
NVTE_CHECK(qkvi.stride(3) == 1, "Wrong stride.");
NVTE_CHECK(qkvi.stride(2) == 3 * qkvi.size(3), "Wrong stride.");
NVTE_CHECK(qkvi.stride(1) == 3 * qkvi.size(3) * qkvi.size(2), "Wrong stride.");
NVTE_CHECK(qkvi.stride(0) == 3 * qkvi.size(3) * qkvi.size(2) * qkvi.size(1), "Wrong stride.");
// [s, b, n, h * 3] -> [3, b, s, n, h]
std::vector<int64_t> shape = {3, qkvi.size(1), qkvi.size(0), qkvi.size(2), qkvi.size(3)};
at::Tensor qkv = at::empty(shape, at::CUDA(qkvi.scalar_type()));
size_t warps = qkvi.size(0) * qkvi.size(1);
size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = flash_attention::block_size;
if (qkvi.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
flash_attention::prepare_kernel_fwd<dtype><<<grid, threads, 0,
at::cuda::getCurrentCUDAStream()>>>(
qkvi.data_ptr<dtype>(),
qkv.data_ptr<dtype>(),
shape[1],
shape[2],
shape[3],
shape[4]);
} else {
using dtype = at::BFloat16;
flash_attention::prepare_kernel_fwd<dtype><<<grid, threads, 0,
at::cuda::getCurrentCUDAStream()>>>(
qkvi.data_ptr<dtype>(),
qkv.data_ptr<dtype>(),
shape[1],
shape[2],
shape[3],
shape[4]);
}
return qkv;
}
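For reference, the layout change that fa_prepare_fwd performs can be written in a few lines of plain PyTorch; the CUDA kernel exists to do the same reshuffle with warp-level vectorized copies rather than a generic permute plus contiguous. Note that the extension actually receives the strided query_layer view and recovers the full buffer from its strides, while the sketch below starts from the full interleaved tensor. A minimal sketch with illustrative names and shapes (not part of the diff):

import torch

def fa_prepare_fwd_reference(mixed_x_layer: torch.Tensor) -> torch.Tensor:
    # Interleaved QKV [s, b, n, 3*h] -> contiguous [3, b, s, n, h]
    s, b, n, h3 = mixed_x_layer.shape
    h = h3 // 3
    qkv = mixed_x_layer.view(s, b, n, 3, h).permute(3, 1, 0, 2, 4)
    return qkv.contiguous()

x = torch.randn(16, 2, 4, 3 * 128, dtype=torch.float16)  # head dim 128, as the checks require
assert fa_prepare_fwd_reference(x).shape == (3, 2, 16, 4, 128)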
at::Tensor fa_prepare_bwd(at::Tensor q, at::Tensor k, at::Tensor v) {
NVTE_CHECK(q.is_contiguous());
NVTE_CHECK(k.is_contiguous());
NVTE_CHECK(v.is_contiguous());
NVTE_CHECK(q.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(k.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(v.dim() == 4, "Expected 4-dim tensor.");
NVTE_CHECK(q.scalar_type() == at::ScalarType::Half ||
q.scalar_type() == at::ScalarType::BFloat16);
NVTE_CHECK(k.scalar_type() == q.scalar_type());
NVTE_CHECK(v.scalar_type() == q.scalar_type());
NVTE_CHECK(q.size(3) % flash_attention::load_size == 0);
NVTE_CHECK(q.size(3) == flash_attention::load_size);
NVTE_CHECK(k.size(3) % flash_attention::load_size == 0);
NVTE_CHECK(k.size(3) == flash_attention::load_size);
NVTE_CHECK(v.size(3) % flash_attention::load_size == 0);
NVTE_CHECK(v.size(3) == flash_attention::load_size);
// 3 x [s, b, n, h] -> [b, s, n, 3 * h]
std::vector<int64_t> shape = {q.size(1), q.size(0), q.size(2), 3 * q.size(3)};
at::Tensor qkv = at::empty(shape, at::CUDA(q.scalar_type()));
size_t warps = q.size(0) * q.size(1);
size_t warps_per_block = flash_attention::block_size / flash_attention::warp_size;
size_t blocks = (warps + warps_per_block - 1) / warps_per_block;
dim3 grid(blocks, 3);
int threads = flash_attention::block_size;
if (q.scalar_type() == at::ScalarType::Half) {
using dtype = at::Half;
flash_attention::prepare_kernel_bwd<dtype><<<grid, threads, 0,
at::cuda::getCurrentCUDAStream()>>>(
q.data_ptr<dtype>(),
k.data_ptr<dtype>(),
v.data_ptr<dtype>(),
qkv.data_ptr<dtype>(),
q.size(0),
q.size(1),
q.size(2),
q.size(3));
} else {
using dtype = at::BFloat16;
flash_attention::prepare_kernel_bwd<dtype><<<grid, threads, 0,
at::cuda::getCurrentCUDAStream()>>>(
q.data_ptr<dtype>(),
k.data_ptr<dtype>(),
v.data_ptr<dtype>(),
qkv.data_ptr<dtype>(),
q.size(0),
q.size(1),
q.size(2),
q.size(3));
}
return qkv;
}
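The backward preparation is the reverse packing: three contiguous gradient tensors are merged into one buffer whose last dimension interleaves q, k and v per attention head, so the later split along the last dimension in the Python-side _PrepareQKVForFA.backward returns gradients in the original interleaved layout, with the first two dimensions swapped back. A plain-PyTorch sketch of the same transform, using the dimension labels from the comment above (illustrative, not part of the diff):

import torch

def fa_prepare_bwd_reference(dq, dk, dv):
    # 3 x contiguous [s, b, n, h] -> contiguous [b, s, n, 3*h],
    # with q/k/v interleaved per head in the last dimension.
    s, b, n, h = dq.shape
    dqkv = torch.stack((dq, dk, dv), dim=3).transpose(0, 1)  # [b, s, n, 3, h]
    return dqkv.reshape(b, s, n, 3 * h)

g = [torch.randn(16, 2, 4, 128, dtype=torch.float16) for _ in range(3)]
assert fa_prepare_bwd_reference(*g).shape == (2, 16, 4, 3 * 128)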
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
// Softmax functions
@@ -1812,6 +1981,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"Fused Attention FP8/BF16/FP16 BWD with packed KV");
m.def("fp8_transpose", &fp8_transpose, "Transpose with FP8 I/O");
m.def("fp8_gelu", &fp8_gelu, "GeLU with FP8 output");
m.def("fa_prepare_fwd", &fa_prepare_fwd, "Prepare QKV for Flash Attention");
m.def("fa_prepare_bwd", &fa_prepare_bwd, "Backward of QKV preparation for Flash Attention");
// Misc
m.def("get_cublasLt_version", &get_cublasLt_version, "Get cublasLt version");
@@ -77,6 +77,48 @@ class DropPath(torch.nn.Module):
output = hidden_state.div(keep_prob) * random_tensor
return output
class _SplitLastDim(torch.autograd.Function):
""""""
@staticmethod
def forward(ctx,
mixed_x_layer: torch.Tensor,
num_parts: int
) -> Tuple[torch.Tensor, ...]:
return split_tensor_along_dim(mixed_x_layer, -1, num_parts)
@staticmethod
def backward(ctx,
*grad_outputs):
assert len(grad_outputs) > 0, "No gradients received for backprop!"
noop_ok = True
strides = grad_outputs[0].stride()
data_ptr = grad_outputs[0].untyped_storage().data_ptr()
shape = grad_outputs[0].shape
last_dim_size = grad_outputs[0].shape[-1]
for i, tensor in enumerate(grad_outputs):
if (tensor.stride() != strides or
tensor.shape != shape or
tensor.untyped_storage().data_ptr() != data_ptr or
tensor.storage_offset() != i * last_dim_size):
noop_ok = False
break
if noop_ok:
ret = torch.Tensor().to(device=grad_outputs[0].device,
dtype=grad_outputs[0].dtype)
new_shape = list(shape)
new_shape[-1] = new_shape[-1] * len(grad_outputs)
ret.set_(grad_outputs[0].untyped_storage(),
grad_outputs[0].storage_offset(),
new_shape,
grad_outputs[0].stride()
)
return ret, None
return torch.cat(grad_outputs, dim = -1), None
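The backward of _SplitLastDim avoids torch.cat whenever the incoming gradients are still untouched, equally-strided views into one contiguous buffer; in that case a fresh tensor is simply pointed at the existing storage with Tensor.set_(). A small illustration of the situation the no-copy path detects, with hypothetical shapes (not part of the diff):

import torch

x = torch.randn(4, 8, 12)
a, b, c = torch.chunk(x, 3, dim=-1)  # three views into x, no copy
assert all(t.untyped_storage().data_ptr() == x.untyped_storage().data_ptr()
           for t in (a, b, c))
assert [t.storage_offset() for t in (a, b, c)] == [0, 4, 8]  # i * last_dim_size
# Under these conditions the full gradient can be rebuilt by re-wrapping the
# original storage (Tensor.set_) instead of allocating torch.cat((a, b, c), -1).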
class UnfusedDotProductAttention(torch.nn.Module):
"""Parallel attention w/o QKV and Proj Gemms
@@ -204,6 +246,56 @@ class UnfusedDotProductAttention(torch.nn.Module):
return context_layer
class _PrepareQKVForFA(torch.autograd.Function):
"""This class converts QKV from interleaved (s, b, ...) layout
to separate contiguous q, k, v tensors in (b, s, ...) layout."""
@staticmethod
def forward(ctx,
query_layer: torch.Tensor,
key_layer: torch.Tensor,
value_layer: torch.Tensor
) -> torch.Tensor:
# All inputs received are non-contiguous tensors.
# The `query_layer` tensor is used to access the
# full memory region of the QKV tensor.
qkv = tex.fa_prepare_fwd(query_layer)
q, k, v = split_tensor_along_dim(qkv, 0, 3)
query_layer = torch.squeeze(q, 0)
key_layer = torch.squeeze(k, 0)
value_layer = torch.squeeze(v, 0)
return query_layer, key_layer, value_layer
@staticmethod
def backward(ctx,
dq: torch.Tensor,
dk: torch.Tensor,
dv: torch.Tensor
) -> Tuple[Union[torch.Tensor, None], ...]:
dqkv = tex.fa_prepare_bwd(dq, dk, dv)
dq, dk, dv = split_tensor_along_dim(dqkv, -1, 3)
return dq, dk, dv
def _check_if_interleaved(q, k, v):
data_ptr = q.untyped_storage().data_ptr()
check_ptrs = all(x.untyped_storage().data_ptr() == data_ptr for x in [q, k, v])
if not check_ptrs:
return False
stride = q.stride()
check_strides = all(stride == x.stride() for x in [q, k, v])
if not check_strides:
return False
shape = q.shape
check_shapes = all(shape == x.shape for x in [q, k, v])
if not check_shapes:
return False
last_dim_size = shape[-1]
check_offsets = all(i * last_dim_size == x.storage_offset()
for i, x in enumerate([q, k, v]))
return check_offsets
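_check_if_interleaved answers one question: do q, k and v still alias the same storage in the layout produced by splitting the fused QKV projection along its last dimension? It is one of the conditions gating the fast fa_prepare_fwd path below, alongside a head dimension of 128 and a large enough batch times sequence length; inputs that broke the aliasing fall back to the transpose-and-contiguous route. A small sketch of inputs that pass the check, using torch.chunk in place of split_tensor_along_dim and hypothetical shapes (not part of the diff):

import torch

s, b, n, h = 16, 2, 4, 128
mixed = torch.randn(s, b, n, 3 * h, dtype=torch.float16)  # fused QKV projection output
q, k, v = torch.chunk(mixed, 3, dim=-1)                   # interleaved views
assert all(t.untyped_storage().data_ptr() == mixed.untyped_storage().data_ptr()
           for t in (q, k, v))
assert q.stride() == k.stride() == v.stride() and q.shape == k.shape == v.shape
assert [t.storage_offset() for t in (q, k, v)] == [0, h, 2 * h]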
class FlashAttention(torch.nn.Module):
"""Dot product attention implementation by using the flash-attn package.
@@ -252,8 +344,17 @@
attention_mask is None
), 'FlashAttention currently does not support external attention mask.'
# For now just 128, will make it more general in the future
if (query_layer.shape[-1] == 128 and
query_layer.shape[0] * query_layer.shape[1] >= 512 and
_check_if_interleaved(query_layer, key_layer, value_layer)):
query_layer, key_layer, value_layer = _PrepareQKVForFA.apply(query_layer,
key_layer,
value_layer)
else:
query_layer, key_layer, value_layer = [x.transpose(0,1).contiguous()
for x in (query_layer, key_layer, value_layer)]
batch_size, seqlen = query_layer.shape[0], query_layer.shape[1]
@@ -731,9 +832,12 @@
mixed_x_layer = mixed_x_layer.view(*new_tensor_shape)
# mixed_x_layer --> 3 [sq, b, np, hn]
if split_dim == -1 and not is_in_onnx_export_mode():
query_layer, key_layer, value_layer = _SplitLastDim.apply(mixed_x_layer, 3)
else:
query_layer, key_layer, value_layer = split_tensor_along_dim(
mixed_x_layer, split_dim, 3
)
else:
# Attention heads [sk, b, h] --> [sk, b, (np * 2 * hn)]
mixed_kv_layer = self.key_value(
@@ -761,7 +865,10 @@
mixed_kv_layer = mixed_kv_layer.view(*new_tensor_shape)
# mixed_kv_layer --> 2 [sk, b, np, hn]
if split_dim == -1 and not is_in_onnx_export_mode():
key_layer, value_layer = _SplitLastDim.apply(mixed_kv_layer, 2)
else:
key_layer, value_layer = split_tensor_along_dim(mixed_kv_layer, split_dim, 2)
# Attention head [sq, b, h] --> [sq, b, hp]
if self.input_layernorm: