Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
OpenDAS
AutoAWQ
Commits
1b0af2d3
Unverified
Commit
1b0af2d3
authored
Sep 13, 2023
by
Casper
Committed by
GitHub
Sep 13, 2023
Browse files
Merge pull request #40 from casper-hansen/new_kernel
[NEW] GEMV kernel implementation
parents
84fb7e98
f264ebb3
Changes
34
Hide whitespace changes
Inline
Side-by-side
Showing
14 changed files
with
4272 additions
and
116 deletions
+4272
-116
awq_cuda/attention/decoder_masked_multihead_attention.h
awq_cuda/attention/decoder_masked_multihead_attention.h
+184
-0
awq_cuda/attention/decoder_masked_multihead_attention_template.hpp
...attention/decoder_masked_multihead_attention_template.hpp
+1608
-0
awq_cuda/attention/decoder_masked_multihead_attention_utils.h
...cuda/attention/decoder_masked_multihead_attention_utils.h
+1786
-0
awq_cuda/attention/ft_attention.cpp
awq_cuda/attention/ft_attention.cpp
+182
-0
awq_cuda/attention/ft_attention.h
awq_cuda/attention/ft_attention.h
+15
-0
awq_cuda/position_embedding/pos_encoding.h
awq_cuda/position_embedding/pos_encoding.h
+2
-3
awq_cuda/position_embedding/pos_encoding_kernels.cu
awq_cuda/position_embedding/pos_encoding_kernels.cu
+51
-100
awq_cuda/pybind.cpp
awq_cuda/pybind.cpp
+9
-2
awq_cuda/quantization/gemm_cuda_gen.cu
awq_cuda/quantization/gemm_cuda_gen.cu
+4
-2
awq_cuda/quantization/gemv_cuda.cu
awq_cuda/quantization/gemv_cuda.cu
+249
-0
awq_cuda/quantization/gemv_cuda.h
awq_cuda/quantization/gemv_cuda.h
+9
-0
examples/basic_quant.py
examples/basic_quant.py
+1
-1
examples/benchmark.py
examples/benchmark.py
+132
-0
setup.py
setup.py
+40
-8
No files found.
awq_cuda/attention/decoder_masked_multihead_attention.h
0 → 100644
View file @
1b0af2d3
// Downloaded from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
exit(1); \
} \
} while (0)
////////////////////////////////////////////////////////////////////////////////////////////////////
// The structure of parameters for the masked multihead attention kernel.
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
template
<
typename
T
>
struct
Multihead_attention_params_base
{
// The output buffer. Dimensions B x D.
T
*
out
=
nullptr
;
// The input Qs and the associated bias. Dimensions B x D and D, resp.
const
T
*
q
=
nullptr
,
*
q_bias
=
nullptr
;
// The input Ks and the associated bias. Dimensions B x D and D, resp.
const
T
*
k
=
nullptr
,
*
k_bias
=
nullptr
;
// The input Vs and the associated bias. Dimensions B x D and D, resp.
const
T
*
v
=
nullptr
,
*
v_bias
=
nullptr
;
// The cache for the Ks. The size must be at least B x L x D.
T
*
k_cache
=
nullptr
;
// The cache for the Vs. The size must be at least B x L x D.
T
*
v_cache
=
nullptr
;
// The indirections to use for cache when beam sampling.
const
int
*
cache_indir
=
nullptr
;
// Stride to handle the case when KQV is a single buffer
int
stride
=
0
;
// The batch size.
int
batch_size
=
0
;
// The beam width
int
beam_width
=
0
;
// The sequence length.
int
memory_max_len
=
0
;
// The number of heads (H).
int
num_heads
=
0
;
// The number of heads for KV cache.
int
num_kv_heads
=
0
;
// The hidden dimension per head (Dh).
int
hidden_size_per_head
=
0
;
// The per-head latent space reserved for rotary embeddings.
int
rotary_embedding_dim
=
0
;
bool
neox_rotary_style
=
false
;
float
rotary_base
=
0.0
f
;
// The maximum length of input sentences.
int
max_input_length
=
0
;
// The current timestep. TODO(bhsueh) Check that do we only this param in cross attention?
int
timestep
=
0
;
// The current timestep of each sentences (support different timestep for different sentences)
// The 1.f / sqrt(Dh). Computed on the host.
float
inv_sqrt_dh
=
0.0
f
;
// Used when we have some input context like gpt
const
int
*
total_padding_tokens
=
nullptr
;
const
bool
*
masked_tokens
=
nullptr
;
const
int
*
prefix_prompt_lengths
=
nullptr
;
int
max_prefix_prompt_length
=
0
;
const
T
*
relative_attention_bias
=
nullptr
;
int
relative_attention_bias_stride
=
0
;
// The slope per head of linear position bias to attention score (H).
const
float
*
linear_bias_slopes
=
nullptr
;
const
T
*
ia3_key_weights
=
nullptr
;
const
T
*
ia3_value_weights
=
nullptr
;
const
int
*
ia3_tasks
=
nullptr
;
const
float
*
qkv_scale_out
=
nullptr
;
const
float
*
attention_out_scale
=
nullptr
;
int
int8_mode
=
0
;
};
template
<
typename
T
,
bool
CROSS_ATTENTION
>
struct
Multihead_attention_params
:
public
Multihead_attention_params_base
<
T
>
{
// output cross attentions
float
*
cross_attention_out
=
nullptr
;
int
max_decoder_seq_len
=
0
;
bool
is_return_cross_attentions
=
false
;
// allows to exist attention eary
bool
*
finished
=
nullptr
;
// required in case of cross attention
// will need it here till if constexpr in c++17
int
*
memory_length_per_sample
=
nullptr
;
// required in case of masked attention with different length
const
int
*
length_per_sample
=
nullptr
;
};
template
<
typename
T
>
struct
Multihead_attention_params
<
T
,
true
>:
public
Multihead_attention_params_base
<
T
>
{
// output cross attentions
float
*
cross_attention_out
=
nullptr
;
int
max_decoder_seq_len
=
0
;
bool
is_return_cross_attentions
=
false
;
// allows to exist attention eary
bool
*
finished
=
nullptr
;
// required in case of cross attention
int
*
memory_length_per_sample
=
nullptr
;
// required in case of masked attention with different length
const
int
*
length_per_sample
=
nullptr
;
};
template
<
class
T
>
using
Masked_multihead_attention_params
=
Multihead_attention_params
<
T
,
false
>
;
template
<
class
T
>
using
Cross_multihead_attention_params
=
Multihead_attention_params
<
T
,
true
>
;
template
<
typename
T
>
struct
outputCrossAttentionParam
{
// max decoder output length
int
max_decoder_seq_len
=
0
;
T
*
cross_attention_out
=
nullptr
;
bool
is_return_cross_attentions
=
false
;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
void
masked_multihead_attention
(
const
Masked_multihead_attention_params
<
float
>&
params
,
const
cudaStream_t
&
stream
);
void
masked_multihead_attention
(
const
Masked_multihead_attention_params
<
uint16_t
>&
params
,
const
cudaStream_t
&
stream
);
#ifdef ENABLE_BF16
void
masked_multihead_attention
(
const
Masked_multihead_attention_params
<
__nv_bfloat16
>&
params
,
const
cudaStream_t
&
stream
);
#endif
void
cross_multihead_attention
(
const
Cross_multihead_attention_params
<
float
>&
params
,
const
cudaStream_t
&
stream
);
void
cross_multihead_attention
(
const
Cross_multihead_attention_params
<
uint16_t
>&
params
,
const
cudaStream_t
&
stream
);
#ifdef ENABLE_BF16
void
cross_multihead_attention
(
const
Cross_multihead_attention_params
<
__nv_bfloat16
>&
params
,
const
cudaStream_t
&
stream
);
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
awq_cuda/attention/decoder_masked_multihead_attention_template.hpp
0 → 100644
View file @
1b0af2d3
// Downloaded from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "decoder_masked_multihead_attention.h"
#include "decoder_masked_multihead_attention_utils.h"
#include "cuda_bf16_wrapper.h"
#include "cuda_bf16_fallbacks.cuh"
#include <assert.h>
#include <float.h>
#include <type_traits>
// #define MMHA_USE_HMMA_FOR_REDUCTION
// Below are knobs to extend FP32 accumulation for higher FP16 accuracy
// Does not seem to affect the accuracy that much
#define MMHA_USE_FP32_ACUM_FOR_FMA
// Seems to slightly improve the accuracy
#define MMHA_USE_FP32_ACUM_FOR_OUT
#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT)
// Does not seem to improve the accuracy
//#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#endif
namespace
mmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
//
// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use
// 64, 128 and 256 threads per block.
//
// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
// cache buffer helps with memory accesses and contains keys with bias.
//
// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and
// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The
// values for x are chosen to create chunks of 16 bytes.
//
// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs
// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At
// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an
// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32.
//
// After that loop, a parallel softmax is computed across the different Q * K^T values stored in
// shared memory.
//
// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
// timesteps are computed by loop iteration. As with the keys, the values are read from a cache
// except for the current timestep. The layout of the cache buffer for the values is much simpler
// as it is [B, H, L, Dh].
//
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
,
int
Dh
>
struct
Qk_vec_
{
};
template
<
>
struct
Qk_vec_
<
float
,
32
>
{
using
Type
=
float
;
};
template
<
>
struct
Qk_vec_
<
float
,
64
>
{
using
Type
=
float2
;
};
template
<
>
struct
Qk_vec_
<
float
,
128
>
{
using
Type
=
float4
;
};
template
<
>
struct
Qk_vec_
<
float
,
256
>
{
using
Type
=
float4
;
};
template
<
>
struct
Qk_vec_
<
uint16_t
,
32
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
Qk_vec_
<
uint16_t
,
64
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
Qk_vec_
<
uint16_t
,
128
>
{
using
Type
=
uint2
;
};
template
<
>
struct
Qk_vec_
<
uint16_t
,
256
>
{
using
Type
=
uint4
;
};
#ifdef ENABLE_BF16
template
<
>
struct
Qk_vec_
<
__nv_bfloat16
,
32
>
{
using
Type
=
__nv_bfloat162
;
};
template
<
>
struct
Qk_vec_
<
__nv_bfloat16
,
64
>
{
using
Type
=
__nv_bfloat162
;
};
template
<
>
struct
Qk_vec_
<
__nv_bfloat16
,
128
>
{
using
Type
=
bf16_4_t
;
};
template
<
>
struct
Qk_vec_
<
__nv_bfloat16
,
256
>
{
using
Type
=
bf16_8_t
;
};
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
,
int
THREADS_PER_KEY
>
struct
K_vec_
{
};
template
<
>
struct
K_vec_
<
float
,
4
>
{
using
Type
=
float
;
};
template
<
>
struct
K_vec_
<
float
,
2
>
{
using
Type
=
float2
;
};
template
<
>
struct
K_vec_
<
float
,
1
>
{
using
Type
=
float4
;
};
template
<
>
struct
K_vec_
<
uint16_t
,
4
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
K_vec_
<
uint16_t
,
2
>
{
using
Type
=
uint2
;
};
template
<
>
struct
K_vec_
<
uint16_t
,
1
>
{
using
Type
=
uint4
;
};
#ifdef ENABLE_BF16
template
<
>
struct
K_vec_
<
__nv_bfloat16
,
4
>
{
using
Type
=
__nv_bfloat162
;
};
template
<
>
struct
K_vec_
<
__nv_bfloat16
,
2
>
{
using
Type
=
bf16_4_t
;
};
template
<
>
struct
K_vec_
<
__nv_bfloat16
,
1
>
{
using
Type
=
bf16_8_t
;
};
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
,
int
V_VEC_SIZE
>
struct
V_vec_
{
};
template
<
>
struct
V_vec_
<
float
,
1
>
{
using
Type
=
float
;
};
template
<
>
struct
V_vec_
<
float
,
2
>
{
using
Type
=
float2
;
};
template
<
>
struct
V_vec_
<
float
,
4
>
{
using
Type
=
float4
;
};
template
<
>
struct
V_vec_
<
uint16_t
,
2
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
V_vec_
<
uint16_t
,
4
>
{
using
Type
=
uint2
;
};
template
<
>
struct
V_vec_
<
uint16_t
,
8
>
{
using
Type
=
uint4
;
};
#ifdef ENABLE_BF16
template
<
>
struct
V_vec_
<
__nv_bfloat16
,
2
>
{
using
Type
=
__nv_bfloat162
;
};
template
<
>
struct
V_vec_
<
__nv_bfloat16
,
4
>
{
using
Type
=
bf16_4_t
;
};
template
<
>
struct
V_vec_
<
__nv_bfloat16
,
8
>
{
using
Type
=
bf16_8_t
;
};
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
template
<
typename
T
>
struct
Qk_vec_acum_fp32_
{
};
template
<
>
struct
Qk_vec_acum_fp32_
<
float
>
{
using
Type
=
float
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
float2
>
{
using
Type
=
float2
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
float4
>
{
using
Type
=
float4
;
};
// template<> struct Qk_vec_acum_fp32_<uint16_t> { using Type = float; };
template
<
>
struct
Qk_vec_acum_fp32_
<
uint32_t
>
{
using
Type
=
float2
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
uint2
>
{
using
Type
=
Float4_
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
uint4
>
{
using
Type
=
Float8_
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
__nv_bfloat16
>
{
using
Type
=
float
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
__nv_bfloat162
>
{
using
Type
=
float2
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
bf16_4_t
>
{
using
Type
=
Float4_
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
bf16_8_t
>
{
using
Type
=
Float8_
;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
>
struct
K_vec_acum_fp32_
{
};
template
<
>
struct
K_vec_acum_fp32_
<
float
>
{
using
Type
=
float
;
};
template
<
>
struct
K_vec_acum_fp32_
<
float2
>
{
using
Type
=
float2
;
};
template
<
>
struct
K_vec_acum_fp32_
<
float4
>
{
using
Type
=
float4
;
};
template
<
>
struct
K_vec_acum_fp32_
<
uint32_t
>
{
using
Type
=
float2
;
};
template
<
>
struct
K_vec_acum_fp32_
<
uint2
>
{
using
Type
=
Float4_
;
};
template
<
>
struct
K_vec_acum_fp32_
<
uint4
>
{
using
Type
=
Float8_
;
};
template
<
>
struct
K_vec_acum_fp32_
<
__nv_bfloat16
>
{
using
Type
=
float
;
};
template
<
>
struct
K_vec_acum_fp32_
<
__nv_bfloat162
>
{
using
Type
=
float2
;
};
template
<
>
struct
K_vec_acum_fp32_
<
bf16_4_t
>
{
using
Type
=
Float4_
;
};
template
<
>
struct
K_vec_acum_fp32_
<
bf16_8_t
>
{
using
Type
=
Float8_
;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template
<
typename
T
>
struct
V_vec_acum_fp32_
{
};
template
<
>
struct
V_vec_acum_fp32_
<
float
>
{
using
Type
=
float
;
};
template
<
>
struct
V_vec_acum_fp32_
<
float2
>
{
using
Type
=
float2
;
};
template
<
>
struct
V_vec_acum_fp32_
<
float4
>
{
using
Type
=
float4
;
};
template
<
>
struct
V_vec_acum_fp32_
<
uint32_t
>
{
using
Type
=
float2
;
};
template
<
>
struct
V_vec_acum_fp32_
<
uint2
>
{
using
Type
=
Float4_
;
};
template
<
>
struct
V_vec_acum_fp32_
<
uint4
>
{
using
Type
=
Float8_
;
};
#ifdef ENABLE_BF16
template
<
>
struct
V_vec_acum_fp32_
<
__nv_bfloat162
>
{
using
Type
=
float2
;
};
template
<
>
struct
V_vec_acum_fp32_
<
bf16_4_t
>
{
using
Type
=
Float4_
;
};
template
<
>
struct
V_vec_acum_fp32_
<
bf16_8_t
>
{
using
Type
=
Float8_
;
};
#endif // ENABLE_BF16
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
THREADS_PER_KEY
,
typename
K_vec
,
int
N
>
inline
__device__
float
qk_dot_
(
const
K_vec
(
&
q
)[
N
],
const
K_vec
(
&
k
)[
N
])
{
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using
K_vec_acum
=
typename
K_vec_acum_fp32_
<
K_vec
>::
Type
;
#else
using
K_vec_acum
=
K_vec
;
#endif
// Compute the parallel products for Q*K^T (treat vector lanes separately).
K_vec_acum
qk_vec
=
mul
<
K_vec_acum
,
K_vec
,
K_vec
>
(
q
[
0
],
k
[
0
]);
#pragma unroll
for
(
int
ii
=
1
;
ii
<
N
;
++
ii
)
{
qk_vec
=
fma
(
q
[
ii
],
k
[
ii
],
qk_vec
);
}
// Finalize the reduction across lanes.
float
qk
=
sum
(
qk_vec
);
#pragma unroll
for
(
int
mask
=
THREADS_PER_KEY
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
qk
,
mask
);
}
return
qk
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
,
int
THREADS_PER_KEY
>
struct
Qk_dot
{
template
<
typename
K_vec
,
int
N
>
static
inline
__device__
float
dot
(
const
K_vec
(
&
q
)[
N
],
const
K_vec
(
&
k
)[
N
])
{
return
qk_dot_
<
THREADS_PER_KEY
>
(
q
,
k
);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
hmma_fp32
(
const
uint2
&
a
,
uint32_t
b
)
{
float4
c
;
float
zero
=
0.
f
;
asm
volatile
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32
\n
"
" {%0, %1, %2, %3},
\n
"
" {%4, %5},
\n
"
" {%6},
\n
"
" {%7, %7, %7, %7};
\n
"
:
"=f"
(
c
.
x
),
"=f"
(
c
.
y
),
"=f"
(
c
.
z
),
"=f"
(
c
.
w
)
:
"r"
(
a
.
x
)
"r"
(
a
.
y
),
"r"
(
b
),
"f"
(
zero
));
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
inline
__device__
float
qk_hmma_dot_
(
const
uint32_t
(
&
q
)[
N
],
const
uint32_t
(
&
k
)[
N
])
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using
K_vec_acum
=
typename
K_vec_acum_fp32_
<
uint32_t
>::
Type
;
#else
using
K_vec_acum
=
uint32_t
;
#endif
K_vec_acum
qk_vec
=
mul
<
K_vec_acum
,
uint32_t
,
uint32_t
>
(
q
[
0
],
k
[
0
]);
#pragma unroll
for
(
int
ii
=
1
;
ii
<
N
;
++
ii
)
{
qk_vec
=
fma
(
q
[
ii
],
k
[
ii
],
qk_vec
);
}
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
uint32_t
qk_vec_
=
float2_to_half2
(
qk_vec
);
return
hmma_fp32
(
make_uint2
(
qk_vec_
,
0u
),
0x3c003c00u
).
x
;
#else
return
hmma_fp32
(
make_uint2
(
qk_vec
,
0u
),
0x3c003c00u
).
x
;
#endif
#else
return
0.
f
;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
struct
Qk_dot
<
uint16_t
,
4
>
{
template
<
int
N
>
static
inline
__device__
float
dot
(
const
uint32_t
(
&
q
)[
N
],
const
uint32_t
(
&
k
)[
N
])
{
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
return
qk_hmma_dot_
(
q
,
k
);
#else
return
qk_dot_
<
4
>
(
q
,
k
);
#endif // defined MMHA_USE_HMMA_FOR_REDUCTION
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
WARPS_PER_BLOCK
,
int
WARP_SIZE
=
32
>
inline
__device__
float
block_sum
(
float
*
red_smem
,
float
sum
)
{
// Decompose the thread index into warp / lane.
int
warp
=
threadIdx
.
x
/
WARP_SIZE
;
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
// Compute the sum per warp.
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
sum
,
mask
);
}
// Warp leaders store the data to shared memory.
if
(
lane
==
0
)
{
red_smem
[
warp
]
=
sum
;
}
// Make sure the data is in shared memory.
__syncthreads
();
// The warps compute the final sums.
if
(
lane
<
WARPS_PER_BLOCK
)
{
sum
=
red_smem
[
lane
];
}
// Parallel reduction inside the warp.
#pragma unroll
for
(
int
mask
=
WARPS_PER_BLOCK
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
sum
,
mask
);
}
// Broadcast to other threads.
return
__shfl_sync
(
uint32_t
(
-
1
),
sum
,
0
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
float
&
dst
,
float
src
)
{
dst
=
src
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint16_t
&
dst
,
float
src
)
{
dst
=
float_to_half
(
src
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint32_t
&
dst
,
float2
src
)
{
dst
=
float2_to_half2
(
src
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
void
convert_from_float
(
__nv_bfloat16
&
dst
,
float
src
)
{
dst
=
__float2bfloat16
(
src
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
__nv_bfloat162
&
dst
,
float2
src
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst
=
__float22bfloat162_rn
(
src
);
#else
dst
=
__floats2bfloat162_rn
(
src
.
x
,
src
.
y
);
#endif
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint2
&
dst
,
Float4_
src
)
{
dst
.
x
=
float2_to_half2
(
src
.
x
);
dst
.
y
=
float2_to_half2
(
src
.
y
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint2
&
dst
,
float4
src
)
{
convert_from_float
(
dst
,
Float4_
{
make_float2
(
src
.
x
,
src
.
y
),
make_float2
(
src
.
z
,
src
.
w
)});
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint4
&
dst
,
Float8_
src
)
{
dst
.
x
=
float2_to_half2
(
src
.
x
);
dst
.
y
=
float2_to_half2
(
src
.
y
);
dst
.
z
=
float2_to_half2
(
src
.
z
);
dst
.
w
=
float2_to_half2
(
src
.
w
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
void
convert_from_float
(
bf16_4_t
&
dst
,
Float4_
src
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst
.
x
=
__float22bfloat162_rn
(
src
.
x
);
dst
.
y
=
__float22bfloat162_rn
(
src
.
y
);
#else
dst
.
x
=
__floats2bfloat162_rn
(
src
.
x
.
x
,
src
.
x
.
y
);
dst
.
y
=
__floats2bfloat162_rn
(
src
.
y
.
x
,
src
.
y
.
y
);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
bf16_4_t
&
dst
,
float4
src
)
{
convert_from_float
(
dst
,
Float4_
{
make_float2
(
src
.
x
,
src
.
y
),
make_float2
(
src
.
z
,
src
.
w
)});
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
bf16_8_t
&
dst
,
Float8_
src
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst
.
x
=
__float22bfloat162_rn
(
src
.
x
);
dst
.
y
=
__float22bfloat162_rn
(
src
.
y
);
dst
.
z
=
__float22bfloat162_rn
(
src
.
z
);
dst
.
w
=
__float22bfloat162_rn
(
src
.
w
);
#else
dst
.
x
=
__floats2bfloat162_rn
(
src
.
x
.
x
,
src
.
x
.
y
);
dst
.
y
=
__floats2bfloat162_rn
(
src
.
y
.
x
,
src
.
y
.
y
);
dst
.
z
=
__floats2bfloat162_rn
(
src
.
z
.
x
,
src
.
z
.
y
);
dst
.
w
=
__floats2bfloat162_rn
(
src
.
w
.
x
,
src
.
w
.
y
);
#endif
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
float2
&
dst
,
float2
src
)
{
dst
=
src
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
float4
&
dst
,
float4
src
)
{
dst
=
src
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
convert_to_float
(
float4
u
)
{
return
u
.
x
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
convert_to_float
(
uint4
u
)
{
float2
tmp
=
half2_to_float2
(
u
.
x
);
return
tmp
.
x
;
}
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
cast_to_float
(
float
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
cast_to_float
(
float2
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
cast_to_float
(
float4
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
cast_to_float
(
Float4_
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
cast_to_float
(
Float8_
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
cast_to_float
(
uint32_t
u
)
{
return
half2_to_float2
(
u
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
cast_to_float
(
uint2
u
)
{
Float4_
tmp
;
tmp
.
x
=
half2_to_float2
(
u
.
x
);
tmp
.
y
=
half2_to_float2
(
u
.
y
);
return
tmp
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
cast_to_float
(
uint4
u
)
{
Float8_
tmp
;
tmp
.
x
=
half2_to_float2
(
u
.
x
);
tmp
.
y
=
half2_to_float2
(
u
.
y
);
tmp
.
z
=
half2_to_float2
(
u
.
z
);
tmp
.
w
=
half2_to_float2
(
u
.
w
);
return
tmp
;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
float_from_int8
(
int8_t
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
float_from_int8
(
int16_t
u
)
{
union
{
int16_t
int16
;
int8_t
int8
[
2
];
};
int16
=
u
;
return
make_float2
(
int8
[
0
],
int8
[
1
]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
float_from_int8
(
int32_t
u
)
{
union
{
int32_t
int32
;
int8_t
int8
[
4
];
};
int32
=
u
;
return
make_float4
(
int8
[
0
],
int8
[
1
],
int8
[
2
],
int8
[
3
]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// clang-format off
inline
__device__
Float8_
float_from_int8
(
int64_t
u
)
{
union
{
int64_t
int64
;
int16_t
int16
[
4
];
};
int64
=
u
;
return
Float8_
{
float_from_int8
(
int16
[
0
]),
float_from_int8
(
int16
[
1
]),
float_from_int8
(
int16
[
2
]),
float_from_int8
(
int16
[
3
])};
}
// clang-format on
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
int8_t
cast_to_int8
(
float
val
)
{
union
{
int8_t
int8
[
2
];
int16_t
int16
;
};
asm
volatile
(
"cvt.rni.sat.s8.f32 %0, %1;"
:
"=h"
(
int16
)
:
"f"
(
val
));
return
int8
[
0
];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
int32_t
cast_to_int8
(
float4
val
)
{
union
{
int8_t
int8
[
4
];
int32_t
int32
;
};
int8
[
0
]
=
cast_to_int8
(
val
.
x
);
int8
[
1
]
=
cast_to_int8
(
val
.
y
);
int8
[
2
]
=
cast_to_int8
(
val
.
z
);
int8
[
3
]
=
cast_to_int8
(
val
.
w
);
return
int32
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
int64_t
cast_to_int8
(
Float8_
val
)
{
union
{
int8_t
int8
[
8
];
int64_t
int64
;
};
int8
[
0
]
=
cast_to_int8
(
val
.
x
.
x
);
int8
[
1
]
=
cast_to_int8
(
val
.
x
.
y
);
int8
[
2
]
=
cast_to_int8
(
val
.
y
.
x
);
int8
[
3
]
=
cast_to_int8
(
val
.
y
.
y
);
int8
[
4
]
=
cast_to_int8
(
val
.
z
.
x
);
int8
[
5
]
=
cast_to_int8
(
val
.
z
.
y
);
int8
[
6
]
=
cast_to_int8
(
val
.
w
.
x
);
int8
[
7
]
=
cast_to_int8
(
val
.
w
.
y
);
return
int64
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
>
inline
__device__
__host__
T
div_up
(
T
m
,
T
n
)
{
return
(
m
+
n
-
1
)
/
n
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
,
bool
DO_CROSS_ATTENTION
>
inline
size_t
smem_size_in_bytes
(
const
Multihead_attention_params
<
T
,
DO_CROSS_ATTENTION
>&
params
,
int
threads_per_value
,
int
threads_per_block
)
{
// The amount of shared memory needed to store the Q*K^T values in float.
const
int
max_timesteps
=
min
(
params
.
timestep
,
params
.
memory_max_len
);
size_t
qk_sz
=
(
DO_CROSS_ATTENTION
)
?
div_up
(
params
.
memory_max_len
+
1
,
4
)
*
16
:
div_up
(
max_timesteps
+
1
,
4
)
*
16
;
// The extra memory needed if we are not using floats for the final logits.
size_t
logits_sz
=
0
;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
if
(
sizeof
(
T
)
!=
4
)
{
// TDOD
logits_sz
=
(
DO_CROSS_ATTENTION
)
?
div_up
(
params
.
memory_max_len
+
1
,
4
)
*
4
*
sizeof
(
T
)
:
div_up
(
max_timesteps
+
1
,
4
)
*
4
*
sizeof
(
T
);
}
#endif
// The total size needed during softmax.
size_t
softmax_sz
=
qk_sz
+
logits_sz
;
// The number of partial rows to reduce in the final reduction.
int
rows_per_red
=
threads_per_block
/
threads_per_value
;
// The amount of storage needed to finalize the outputs.
size_t
red_sz
=
rows_per_red
*
params
.
hidden_size_per_head
*
sizeof
(
T
)
/
2
;
size_t
transpose_rotary_size
=
0
;
if
(
params
.
rotary_embedding_dim
>
0
&&
params
.
neox_rotary_style
)
{
transpose_rotary_size
=
2
*
params
.
rotary_embedding_dim
*
sizeof
(
T
);
}
// The max.
return
max
(
max
(
softmax_sz
,
red_sz
),
transpose_rotary_size
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
constexpr
uint32_t
shfl_mask
(
int
threads
)
{
return
threads
==
32
?
uint32_t
(
-
1
)
:
(
1u
<<
threads
)
-
1u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
// The type of the inputs. Supported types: float and half.
typename
T
,
// The hidden dimension per head.
int
Dh
,
int
Dh_MAX
,
// The number of threads per key.
int
THREADS_PER_KEY
,
// The number of threads per value.
int
THREADS_PER_VALUE
,
// The number of threads in a threadblock.
int
THREADS_PER_BLOCK
,
bool
DO_CROSS_ATTENTION
>
__global__
void
masked_multihead_attention_kernel
(
Multihead_attention_params
<
T
,
DO_CROSS_ATTENTION
>
params
)
{
// Make sure the hidden dimension per head is a multiple of the number of threads per key.
static_assert
(
Dh_MAX
%
THREADS_PER_KEY
==
0
,
""
);
// Make sure the hidden dimension per head is a multiple of the number of threads per value.
static_assert
(
Dh_MAX
%
THREADS_PER_VALUE
==
0
,
""
);
// The size of a warp.
constexpr
int
WARP_SIZE
=
32
;
// The number of warps in a threadblock.
constexpr
int
WARPS_PER_BLOCK
=
THREADS_PER_BLOCK
/
WARP_SIZE
;
// Use smem_size_in_bytes (above) to determine the amount of shared memory.
extern
__shared__
char
smem_
[];
// The shared memory for the Q*K^T values and partial logits in softmax.
float
*
qk_smem
=
reinterpret_cast
<
float
*>
(
smem_
);
// The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
char
*
logits_smem_
=
smem_
;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
if
(
sizeof
(
T
)
!=
4
)
{
// TODO - change to tlength
const
int
max_timesteps
=
min
(
params
.
timestep
,
params
.
memory_max_len
);
logits_smem_
+=
(
DO_CROSS_ATTENTION
)
?
div_up
(
params
.
memory_max_len
+
1
,
4
)
*
16
:
div_up
(
max_timesteps
+
1
,
4
)
*
16
;
}
T
*
logits_smem
=
reinterpret_cast
<
T
*>
(
logits_smem_
);
#else
float
*
logits_smem
=
reinterpret_cast
<
float
*>
(
logits_smem_
);
#endif
// The shared memory to do the final reduction for the output values. Reuse qk_smem.
T
*
out_smem
=
reinterpret_cast
<
T
*>
(
smem_
);
// The shared memory buffers for the block-wide reductions. One for max, one for sum.
__shared__
float
red_smem
[
WARPS_PER_BLOCK
*
2
];
// A vector of Q or K elements for the current timestep.
using
Qk_vec
=
typename
Qk_vec_
<
T
,
Dh_MAX
>::
Type
;
// Use alignment for safely casting the shared buffers as Qk_vec.
// Shared memory to store Q inputs.
__shared__
__align__
(
sizeof
(
Qk_vec
))
T
q_smem
[
Dh_MAX
];
// This is one of the reasons we should have a separate kernel for cross attention
__shared__
__align__
(
sizeof
(
Qk_vec
))
T
bias_smem
[
DO_CROSS_ATTENTION
?
Dh_MAX
:
1
];
// A vector of Q or K elements for the current timestep.
using
Qk_vec
=
typename
Qk_vec_
<
T
,
Dh_MAX
>::
Type
;
// The number of elements per vector.
constexpr
int
QK_VEC_SIZE
=
sizeof
(
Qk_vec
)
/
sizeof
(
T
);
// Make sure the hidden size per head is a multiple of the vector size.
static_assert
(
Dh_MAX
%
QK_VEC_SIZE
==
0
,
""
);
// We will use block wide reduction if needed
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
// The number of vectors per warp.
constexpr
int
QK_VECS_PER_WARP
=
Dh_MAX
/
QK_VEC_SIZE
;
// The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread
// owns x elements, we have to decompose the linear index into chunks of x values and the posi-
// tion of the thread in that chunk.
// The number of elements in a chunk of 16B (that's the x in the above formula).
constexpr
int
QK_ELTS_IN_16B
=
16
/
sizeof
(
T
);
// The number of K vectors in 16B.
constexpr
int
QK_VECS_IN_16B
=
16
/
sizeof
(
Qk_vec
);
// The batch/beam idx
const
int
bi
=
blockIdx
.
y
;
if
(
params
.
finished
!=
nullptr
&&
params
.
finished
[
bi
]
==
true
)
{
return
;
}
// The beam idx
const
int
beami
=
bi
%
params
.
beam_width
;
// The "beam-aware" batch idx
const
int
bbi
=
bi
/
params
.
beam_width
;
// The head.
const
int
num_kv_heads
=
params
.
num_kv_heads
;
const
int
kv_rep
=
(
params
.
num_heads
/
num_kv_heads
);
const
int
hi
=
blockIdx
.
x
;
const
int
hi_kv
=
hi
/
kv_rep
;
// Combine the batch and the head indices.
const
int
bhi
=
bi
*
params
.
num_heads
+
hi
;
const
int
bhi_kv
=
bi
*
(
params
.
num_heads
/
kv_rep
)
+
hi_kv
;
// Combine the "beam-aware" batch idx and the head indices.
const
int
bbhi
=
bbi
*
params
.
beam_width
*
params
.
num_heads
+
hi
;
const
int
bbhi_kv
=
bbi
*
params
.
beam_width
*
(
params
.
num_heads
/
kv_rep
)
+
hi_kv
;
// The thread in the block.
const
int
tidx
=
threadIdx
.
x
;
const
bool
handle_kv
=
!
DO_CROSS_ATTENTION
||
(
DO_CROSS_ATTENTION
&&
params
.
timestep
==
0
);
// Every kv_rep threads have the same kv_cache values. So only the first one writes back.
const
int
write_kv_cache
=
handle_kv
&&
(
hi
%
kv_rep
==
0
);
// While doing the product Q*K^T for the different keys we track the max.
float
qk_max
=
-
FLT_MAX
;
float
qk
=
0.0
F
;
// int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
const
int
q_base_offset
=
bi
*
params
.
stride
+
hi
*
Dh
;
const
int
k_base_offset
=
bi
*
params
.
stride
+
hi_kv
*
Dh
;
const
int
v_base_offset
=
k_base_offset
;
const
size_t
bi_seq_len_offset
=
bi
*
params
.
memory_max_len
;
// int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep;
int
tlength
=
(
DO_CROSS_ATTENTION
)
?
params
.
memory_length_per_sample
[
bi
]
-
1
:
(
params
.
length_per_sample
==
nullptr
)
?
params
.
timestep
:
params
.
length_per_sample
[
bi
]
+
params
.
max_prefix_prompt_length
;
const
int
first_step
=
max
(
0
,
tlength
+
1
-
params
.
memory_max_len
);
const
int
tlength_circ
=
tlength
%
params
.
memory_max_len
;
// First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
const
bool
is_masked
=
tidx
>=
QK_VECS_PER_WARP
;
// The offset in the Q and K buffer also accounts for the batch.
// int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
int
q_offset
=
q_base_offset
+
tidx
*
QK_VEC_SIZE
;
int
k_offset
=
k_base_offset
+
tidx
*
QK_VEC_SIZE
;
int
v_offset
=
k_offset
;
// The offset in the bias buffer.
// int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
int
q_bias_offset
=
hi
*
Dh
+
tidx
*
QK_VEC_SIZE
;
int
k_bias_offset
=
hi_kv
*
Dh
+
tidx
*
QK_VEC_SIZE
;
int
v_bias_offset
=
k_bias_offset
;
const
bool
do_ia3
=
handle_kv
&&
params
.
ia3_tasks
!=
nullptr
;
const
int
ia3_task_id
=
do_ia3
?
params
.
ia3_tasks
[
bbi
]
:
0
;
// Trigger the loads from the Q and K buffers.
Qk_vec
q
;
zero
(
q
);
if
(
!
is_masked
&&
(
Dh
==
Dh_MAX
||
tidx
*
QK_VEC_SIZE
<
Dh
))
{
if
(
params
.
int8_mode
==
2
)
{
using
Packed_Int8_t
=
typename
packed_type
<
int8_t
,
num_elems
<
Qk_vec
>::
value
>::
type
;
using
Packed_Float_t
=
typename
packed_type
<
float
,
num_elems
<
Qk_vec
>::
value
>::
type
;
const
auto
q_scaling
=
params
.
qkv_scale_out
[
0
];
const
auto
q_quant
=
*
reinterpret_cast
<
const
Packed_Int8_t
*>
(
&
reinterpret_cast
<
const
int8_t
*>
(
params
.
q
)[
q_offset
]);
convert_from_float
(
q
,
mul
<
Packed_Float_t
,
float
>
(
q_scaling
,
float_from_int8
(
q_quant
)));
}
else
{
q
=
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
params
.
q
[
q_offset
]);
}
}
Qk_vec
k
;
zero
(
k
);
if
(
DO_CROSS_ATTENTION
)
{
// The 16B chunk written by the thread.
int
co
=
tidx
/
QK_VECS_IN_16B
;
// The position of the thread in that 16B chunk.
int
ci
=
tidx
%
QK_VECS_IN_16B
*
QK_VEC_SIZE
;
// Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
int
offset
=
bhi_kv
*
params
.
memory_max_len
*
Dh
+
co
*
params
.
memory_max_len
*
QK_ELTS_IN_16B
+
// params.timestep*QK_ELTS_IN_16B +
tlength
*
QK_ELTS_IN_16B
+
ci
;
k
=
!
is_masked
&&
(
Dh
==
Dh_MAX
||
tidx
*
QK_VEC_SIZE
<
Dh
)
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
params
.
k_cache
[
offset
])
:
k
;
}
else
{
if
(
!
is_masked
&&
(
Dh
==
Dh_MAX
||
tidx
*
QK_VEC_SIZE
<
Dh
))
{
if
(
params
.
int8_mode
==
2
)
{
using
Packed_Int8_t
=
typename
packed_type
<
int8_t
,
num_elems
<
Qk_vec
>::
value
>::
type
;
using
Packed_Float_t
=
typename
packed_type
<
float
,
num_elems
<
Qk_vec
>::
value
>::
type
;
const
auto
k_scaling
=
params
.
qkv_scale_out
[
1
];
const
auto
k_quant
=
*
reinterpret_cast
<
const
Packed_Int8_t
*>
(
&
reinterpret_cast
<
const
int8_t
*>
(
params
.
k
)[
k_offset
]);
convert_from_float
(
k
,
mul
<
Packed_Float_t
,
float
>
(
k_scaling
,
float_from_int8
(
k_quant
)));
}
else
{
k
=
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
params
.
k
[
k_offset
]);
}
}
}
// Trigger the loads from the Q and K bias buffers.
Qk_vec
q_bias
;
zero
(
q_bias
);
q_bias
=
(
!
is_masked
&&
Dh
==
Dh_MAX
||
tidx
*
QK_VEC_SIZE
<
Dh
)
&&
params
.
q_bias
!=
nullptr
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
params
.
q_bias
[
q_bias_offset
])
:
q_bias
;
Qk_vec
k_bias
;
zero
(
k_bias
);
if
(
handle_kv
)
{
k_bias
=
!
is_masked
&&
(
Dh
==
Dh_MAX
||
tidx
*
QK_VEC_SIZE
<
Dh
)
&&
params
.
k_bias
!=
nullptr
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
params
.
k_bias
[
k_bias_offset
])
:
k_bias
;
}
// Computes the Q/K values with bias.
q
=
add
(
q
,
q_bias
);
if
(
handle_kv
)
{
k
=
add
(
k
,
k_bias
);
}
if
(
do_ia3
&&
!
is_masked
)
{
k
=
mul
<
Qk_vec
,
Qk_vec
,
Qk_vec
>
(
k
,
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
params
.
ia3_key_weights
[(
ia3_task_id
*
params
.
num_heads
+
hi
)
*
Dh
+
tidx
*
QK_VEC_SIZE
]));
}
// Padded len
const
int
padd_len
=
(
params
.
total_padding_tokens
==
nullptr
)
?
0
:
params
.
total_padding_tokens
[
bi
];
if
(
params
.
rotary_embedding_dim
>
0
&&
!
params
.
neox_rotary_style
)
{
if
(
handle_kv
)
{
apply_rotary_embedding
(
q
,
k
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_base
);
}
else
{
apply_rotary_embedding
(
q
,
tidx
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_base
);
}
}
else
if
(
params
.
rotary_embedding_dim
>
0
&&
params
.
neox_rotary_style
)
{
const
bool
do_rotary
=
!
is_masked
&&
QK_VEC_SIZE
*
tidx
<
params
.
rotary_embedding_dim
;
T
*
q_smem
=
reinterpret_cast
<
T
*>
(
smem_
);
T
*
k_smem
=
q_smem
+
params
.
rotary_embedding_dim
;
const
int
half_rotary_dim
=
params
.
rotary_embedding_dim
/
2
;
const
int
half_idx
=
(
tidx
*
QK_VEC_SIZE
)
/
half_rotary_dim
;
const
int
intra_half_idx
=
(
tidx
*
QK_VEC_SIZE
)
%
half_rotary_dim
;
const
int
smem_pitch
=
half_rotary_dim
;
// TODO: adjust for bank conflicts
assert
(
half_rotary_dim
%
QK_VEC_SIZE
==
0
);
if
(
do_rotary
)
{
*
reinterpret_cast
<
Qk_vec
*>
(
q_smem
+
half_idx
*
smem_pitch
+
intra_half_idx
)
=
q
;
if
(
handle_kv
)
{
*
reinterpret_cast
<
Qk_vec
*>
(
k_smem
+
half_idx
*
smem_pitch
+
intra_half_idx
)
=
k
;
}
}
__syncthreads
();
const
int
transpose_idx
=
half_idx
*
(
half_rotary_dim
/
2
)
+
intra_half_idx
/
2
;
constexpr
int
tidx_factor
=
(
QK_VEC_SIZE
>
1
)
?
QK_VEC_SIZE
/
2
:
1
;
if
(
do_rotary
)
{
mmha
::
vec_from_smem_transpose
(
q
,
q_smem
,
transpose_idx
,
smem_pitch
);
if
(
handle_kv
)
{
mmha
::
vec_from_smem_transpose
(
k
,
k_smem
,
transpose_idx
,
smem_pitch
);
mmha
::
apply_rotary_embedding
(
q
,
k
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
-
padd_len
,
params
.
rotary_base
);
mmha
::
write_smem_transpose
(
k
,
k_smem
,
transpose_idx
,
smem_pitch
);
}
else
{
mmha
::
apply_rotary_embedding
(
q
,
transpose_idx
/
tidx_factor
,
params
.
rotary_embedding_dim
,
tlength
,
params
.
rotary_base
);
}
mmha
::
write_smem_transpose
(
q
,
q_smem
,
transpose_idx
,
smem_pitch
);
}
__syncthreads
();
if
(
do_rotary
)
{
q
=
*
reinterpret_cast
<
Qk_vec
*>
(
q_smem
+
half_idx
*
smem_pitch
+
intra_half_idx
);
if
(
handle_kv
)
{
k
=
*
reinterpret_cast
<
Qk_vec
*>
(
k_smem
+
half_idx
*
smem_pitch
+
intra_half_idx
);
}
}
__syncthreads
();
}
if
(
!
is_masked
)
{
// Store the Q values to shared memory.
*
reinterpret_cast
<
Qk_vec
*>
(
&
q_smem
[
tidx
*
QK_VEC_SIZE
])
=
q
;
// Store Dh values of k_bias into smem, since will need to add later
// if params.timestep == 0
if
(
DO_CROSS_ATTENTION
&&
params
.
timestep
==
0
)
{
*
reinterpret_cast
<
Qk_vec
*>
(
&
bias_smem
[
tidx
*
QK_VEC_SIZE
])
=
k_bias
;
}
// Write the K values to the global memory cache.
//
// NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
// system. We designed it this way as it allows much better memory loads (and there are many
// more loads) + the stores are really "write and forget" since we won't need the ack before
// the end of the kernel. There's plenty of time for the transactions to complete.
// The 16B chunk written by the thread.
int
co
=
tidx
/
QK_VECS_IN_16B
;
// The position of the thread in that 16B chunk.
int
ci
=
tidx
%
QK_VECS_IN_16B
*
QK_VEC_SIZE
;
// Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
int
offset
=
bhi_kv
*
params
.
memory_max_len
*
Dh
+
co
*
params
.
memory_max_len
*
QK_ELTS_IN_16B
+
// params.timestep*QK_ELTS_IN_16B +
tlength_circ
*
QK_ELTS_IN_16B
+
ci
;
if
(
write_kv_cache
)
{
// Trigger the stores to global memory.
if
(
Dh
==
Dh_MAX
||
co
<
Dh
/
QK_ELTS_IN_16B
)
{
*
reinterpret_cast
<
Qk_vec
*>
(
&
params
.
k_cache
[
offset
])
=
k
;
}
}
// Compute \sum_i Q[i] * K^T[i] for the current timestep.
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using
Qk_vec_acum
=
typename
Qk_vec_acum_fp32_
<
Qk_vec
>::
Type
;
#else
using
Qk_vec_acum
=
Qk_vec
;
#endif
qk
=
dot
<
Qk_vec_acum
,
Qk_vec
>
(
q
,
k
);
if
(
QK_VECS_PER_WARP
<=
WARP_SIZE
)
{
#pragma unroll
for
(
int
mask
=
QK_VECS_PER_WARP
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk
+=
__shfl_xor_sync
(
shfl_mask
(
QK_VECS_PER_WARP
),
qk
,
mask
);
}
}
}
if
(
QK_VECS_PER_WARP
>
WARP_SIZE
)
{
constexpr
int
WARPS_PER_RED
=
(
QK_VECS_PER_WARP
+
WARP_SIZE
-
1
)
/
WARP_SIZE
;
qk
=
block_sum
<
WARPS_PER_RED
>
(
&
red_smem
[
WARPS_PER_RED
],
qk
);
}
// Store that value in shared memory. Keep the Q*K^T value in register for softmax.
if
(
tidx
==
0
)
{
// Normalize qk.
qk
*=
params
.
inv_sqrt_dh
;
if
(
params
.
relative_attention_bias
!=
nullptr
)
{
// TODO (Haotian): check whether we should replace hi with hi_kv,
// although params.relative_attention_bias is usually not used.
qk
=
add
(
qk
,
params
.
relative_attention_bias
[
hi
*
params
.
relative_attention_bias_stride
*
params
.
relative_attention_bias_stride
+
(
tlength
-
padd_len
)
*
params
.
relative_attention_bias_stride
+
(
tlength
-
padd_len
)]);
}
// Add alibi positional encoding
// qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0;
// We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.
qk_max
=
qk
;
qk_smem
[
tlength
-
first_step
]
=
qk
;
// qk_smem[params.timestep] = qk;
}
// Make sure the data is in shared memory.
__syncthreads
();
// The type of queries and keys for the math in the Q*K^T product.
using
K_vec
=
typename
K_vec_
<
T
,
THREADS_PER_KEY
>::
Type
;
// The number of elements per vector.
constexpr
int
K_VEC_SIZE
=
sizeof
(
K_vec
)
/
sizeof
(
T
);
// Make sure the hidden size per head is a multiple of the vector size.
static_assert
(
Dh_MAX
%
K_VEC_SIZE
==
0
,
""
);
// The number of elements per thread.
constexpr
int
K_ELTS_PER_THREAD
=
Dh_MAX
/
THREADS_PER_KEY
;
// The number of vectors per thread.
constexpr
int
K_VECS_PER_THREAD
=
K_ELTS_PER_THREAD
/
K_VEC_SIZE
;
// The position the first key loaded by each thread from the cache buffer (for this B * H).
int
ko
=
tidx
/
THREADS_PER_KEY
;
// The position of the thread in the chunk of keys.
int
ki
=
tidx
%
THREADS_PER_KEY
*
K_VEC_SIZE
;
static_assert
(
Dh_MAX
==
THREADS_PER_KEY
*
K_VEC_SIZE
*
K_VECS_PER_THREAD
);
// Load the Q values from shared memory. The values are reused during the loop on K.
K_vec
q_vec
[
K_VECS_PER_THREAD
];
#pragma unroll
for
(
int
ii
=
0
;
ii
<
K_VECS_PER_THREAD
;
++
ii
)
{
q_vec
[
ii
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
&
q_smem
[
ki
+
ii
*
THREADS_PER_KEY
*
K_VEC_SIZE
]);
}
K_vec
k_bias_vec
[
DO_CROSS_ATTENTION
?
K_VECS_PER_THREAD
:
1
];
if
(
DO_CROSS_ATTENTION
&&
params
.
timestep
==
0
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
K_VECS_PER_THREAD
;
++
ii
)
{
k_bias_vec
[
ii
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
&
bias_smem
[
ki
+
ii
*
THREADS_PER_KEY
*
K_VEC_SIZE
]);
}
}
// The number of timesteps loaded per iteration.
constexpr
int
K_PER_ITER
=
THREADS_PER_BLOCK
/
THREADS_PER_KEY
;
// The number of keys per warp.
constexpr
int
K_PER_WARP
=
WARP_SIZE
/
THREADS_PER_KEY
;
// The base pointer for the key in the cache buffer.
T
*
k_cache
=
&
params
.
k_cache
[
bhi_kv
*
params
.
memory_max_len
*
Dh
+
ki
];
// Base pointer for the beam's batch, before offsetting with indirection buffer
T
*
k_cache_batch
=
&
params
.
k_cache
[
bbhi_kv
*
params
.
memory_max_len
*
Dh
+
ki
];
// Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
// int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
int
ti_end
=
div_up
(
tlength
-
first_step
,
K_PER_WARP
)
*
K_PER_WARP
+
first_step
;
// prefix prompt length if has
const
int
prefix_prompt_length
=
(
params
.
prefix_prompt_lengths
==
nullptr
)
?
0
:
params
.
prefix_prompt_lengths
[
bi
];
// Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
const
bool
has_beams
=
params
.
cache_indir
!=
nullptr
;
const
int
*
beam_indices
=
has_beams
?
&
params
.
cache_indir
[
bi_seq_len_offset
]
:
nullptr
;
for
(
int
ti
=
first_step
+
ko
;
ti
<
ti_end
;
ti
+=
K_PER_ITER
)
{
const
int
ti_circ
=
ti
%
params
.
memory_max_len
;
// The keys loaded from the key cache.
K_vec
k
[
K_VECS_PER_THREAD
];
K_vec
k_vec_zero
;
zero
(
k_vec_zero
);
#pragma unroll
for
(
int
ii
=
0
;
ii
<
K_VECS_PER_THREAD
;
++
ii
)
{
int
jj
=
ii
*
params
.
memory_max_len
+
ti_circ
;
// if( ti < params.timestep ) {
const
bool
within_bounds
=
(
Dh
==
Dh_MAX
||
jj
*
QK_ELTS_IN_16B
<
Dh
*
params
.
memory_max_len
);
if
(
ti
<
tlength
)
{
if
(
!
within_bounds
)
{
k
[
ii
]
=
k_vec_zero
;
}
else
{
if
(
has_beams
)
{
const
int
beam_offset
=
beam_indices
[
ti_circ
]
*
params
.
num_heads
*
params
.
memory_max_len
*
Dh
;
k
[
ii
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
&
k_cache_batch
[
beam_offset
+
jj
*
QK_ELTS_IN_16B
]);
}
else
{
k
[
ii
]
=
*
reinterpret_cast
<
const
K_vec
*>
(
&
k_cache_batch
[
jj
*
QK_ELTS_IN_16B
]);
}
}
// add bias and update k_cache
if
(
DO_CROSS_ATTENTION
&&
params
.
timestep
==
0
)
{
k
[
ii
]
=
add
(
k
[
ii
],
k_bias_vec
[
ii
]);
if
(
do_ia3
)
{
k
[
ii
]
=
mul
<
K_vec
,
K_vec
,
K_vec
>
(
k
[
ii
],
*
reinterpret_cast
<
const
K_vec
*>
(
&
params
.
ia3_key_weights
[(
ia3_task_id
*
params
.
num_heads
+
hi
)
*
Dh
+
ki
+
ii
*
THREADS_PER_KEY
*
K_VEC_SIZE
]));
}
if
(
Dh
==
Dh_MAX
||
jj
*
QK_ELTS_IN_16B
<
Dh
*
params
.
memory_max_len
)
{
*
reinterpret_cast
<
K_vec
*>
(
&
k_cache
[
jj
*
QK_ELTS_IN_16B
])
=
k
[
ii
];
}
}
}
}
// Perform the dot product and normalize qk.
//
// WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
float
qk
=
Qk_dot
<
T
,
THREADS_PER_KEY
>::
dot
(
q_vec
,
k
)
*
params
.
inv_sqrt_dh
;
bool
is_mask
=
(
params
.
masked_tokens
!=
nullptr
)
&&
params
.
masked_tokens
[
bi_seq_len_offset
+
ti
];
// Store the product to shared memory. There's one qk value per timestep. Update the max.
// if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) {
if
(
ti
<
tlength
&&
tidx
%
THREADS_PER_KEY
==
0
)
{
if
(
params
.
relative_attention_bias
!=
nullptr
)
{
qk
=
add
(
qk
,
params
.
relative_attention_bias
[
hi
*
params
.
relative_attention_bias_stride
*
params
.
relative_attention_bias_stride
+
tlength
*
params
.
relative_attention_bias_stride
+
ti
]);
}
if
(
params
.
linear_bias_slopes
!=
nullptr
)
{
// Apply the linear position bias: (ki - qi) * slope[hi].
// The padding token locates between the input context and the generated tokens.
// We need to remove the number of padding tokens in the distance computation.
// ti : 0 1 2 3 4 5 6 7 8 9(tlength)
// token: i i i i p p p o o o where i=input, p=pad, o=output.
// e.g. ti = 2, dist = (9 - 3) - 2 = 4.
int
max_context_length
=
params
.
max_prefix_prompt_length
+
params
.
max_input_length
;
float
dist
=
(
ti
<
max_context_length
?
ti
+
padd_len
:
ti
)
-
tlength
;
qk
+=
mul
<
float
,
float
,
float
>
(
params
.
linear_bias_slopes
[
hi
],
dist
);
}
// Add alibi positional encoding
// qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0;
qk_max
=
is_mask
?
qk_max
:
fmaxf
(
qk_max
,
qk
);
qk_smem
[
ti
-
first_step
]
=
qk
;
}
}
// Perform the final reduction to compute the max inside each warp.
//
// NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
// group so it's not needed to run the reduction inside the group (again).
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
THREADS_PER_KEY
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
__shfl_xor_sync
(
uint32_t
(
-
1
),
qk_max
,
mask
));
}
// Decompose the thread index into warp and lane.
const
int
warp
=
tidx
/
WARP_SIZE
;
const
int
lane
=
tidx
%
WARP_SIZE
;
// The warp leader writes the max to shared memory.
if
(
lane
==
0
)
{
red_smem
[
warp
]
=
qk_max
;
}
// Make sure the products are in shared memory.
__syncthreads
();
// The warps finalize the reduction.
qk_max
=
lane
<
WARPS_PER_BLOCK
?
red_smem
[
lane
]
:
-
FLT_MAX
;
#pragma unroll
for
(
int
mask
=
WARPS_PER_BLOCK
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk_max
=
fmaxf
(
qk_max
,
__shfl_xor_sync
(
uint32_t
(
-
1
),
qk_max
,
mask
));
}
// Broadcast to all the threads in the warp.
qk_max
=
__shfl_sync
(
uint32_t
(
-
1
),
qk_max
,
0
);
// Compute the logits and start the sum.
float
sum
=
0.
f
;
// for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
for
(
int
ti
=
first_step
+
tidx
;
ti
<=
tlength
;
ti
+=
THREADS_PER_BLOCK
)
{
bool
is_mask
=
(
params
.
masked_tokens
!=
nullptr
)
&&
params
.
masked_tokens
[
bi_seq_len_offset
+
ti
];
float
logit
=
is_mask
?
0.
f
:
__expf
(
qk_smem
[
ti
-
first_step
]
-
qk_max
);
sum
+=
logit
;
qk_smem
[
ti
-
first_step
]
=
logit
;
}
// Compute the sum.
sum
=
block_sum
<
WARPS_PER_BLOCK
>
(
&
red_smem
[
WARPS_PER_BLOCK
],
sum
);
// Normalize the logits.
float
inv_sum
=
__fdividef
(
1.
f
,
sum
+
1.e-6
f
);
// for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
const
size_t
cross_attention_out_offset
=
params
.
is_return_cross_attentions
?
bhi_kv
*
params
.
max_decoder_seq_len
*
params
.
memory_max_len
+
params
.
timestep
*
params
.
memory_max_len
:
0
;
for
(
int
ti
=
first_step
+
tidx
;
ti
<=
tlength
;
ti
+=
THREADS_PER_BLOCK
)
{
float
logit
=
qk_smem
[
ti
-
first_step
]
*
inv_sum
;
if
(
params
.
is_return_cross_attentions
)
{
params
.
cross_attention_out
[
cross_attention_out_offset
+
ti
]
=
logit
;
}
convert_from_float
(
logits_smem
[
ti
-
first_step
],
logit
);
}
    // Put Values part below so we leverage __syncthreads
    // from the previous step

    // The number of elements per vector.
    constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
    // A vector of V elements for the current timestep.
    using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;

    // The value computed by this thread.
    int vo = tidx / THREADS_PER_VALUE;
    // The hidden dimensions computed by this particular thread.
    int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;

    // The base pointer for the value in the cache buffer.
    T* v_cache = &params.v_cache[bhi_kv * params.memory_max_len * Dh + vi];
    // Base pointer for the beam's batch, before offsetting with indirection buffer
    T* v_cache_batch = &params.v_cache[bbhi_kv * params.memory_max_len * Dh + vi];

    // The number of values processed per iteration of the loop.
    constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;

    // One group of threads computes the product(s) for the current timestep.
    V_vec v_bias;
    zero(v_bias);
    // if( vo == params.timestep % V_PER_ITER ) {
    if (Dh == Dh_MAX || vi < Dh) {
        if (handle_kv) {
            if (vo == tlength % V_PER_ITER) {
                // Trigger the loads from the V bias buffer.
                if (params.v_bias != nullptr) {
                    v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi_kv * Dh + vi]);
                }
                if (DO_CROSS_ATTENTION) {
                    *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
                }
            }
        }
    }

    // From previous, before values, step
    // Also make sure the logits are in shared memory.
    __syncthreads();

    // Values continued
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
    using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
#else
    using V_vec_acum = V_vec;
#endif
    // The partial outputs computed by each thread.
    V_vec_acum out;
    zero(out);

    // Loop over the timesteps to compute the partial outputs.
    // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
    if (Dh == Dh_MAX || vi < Dh) {
        for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) {
            const int ti_circ = ti % params.memory_max_len;

            // Fetch offset based on cache_indir when beam sampling
            const int beam_src    = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
            const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh;
            // Load the values from the cache.
            V_vec v = *reinterpret_cast<const V_vec*>(&v_cache_batch[beam_offset + ti_circ * Dh]);
            if (DO_CROSS_ATTENTION && params.timestep == 0) {
                v = add(v, *reinterpret_cast<V_vec*>(&bias_smem[vi]));
                if (do_ia3) {
                    v = mul<V_vec, V_vec, V_vec>(
                        v,
                        *reinterpret_cast<const V_vec*>(
                            &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
                }
                *reinterpret_cast<V_vec*>(&v_cache[ti * Dh]) = v;
            }
            // Load the logits from shared memory.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
            float logit = logits_smem[ti - first_step];
            out         = fma(logit, cast_to_float(v), out);
#else
            T logit = logits_smem[ti - first_step];

            // Update the partial sums.
            out = fma(logit, v, out);
#endif
        }
    }

    // One group of threads computes the product(s) for the current timestep.
    // if( vo == params.timestep % V_PER_ITER ) {
    if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {
        V_vec v;
        if (DO_CROSS_ATTENTION) {
            v = *reinterpret_cast<const V_vec*>(&v_cache[tlength * Dh]);
        }
        else {
            // Trigger the loads from the V buffer.
            const auto v_offset = v_base_offset + vi;
            if (params.int8_mode == 2) {
                using Packed_Int8_t  = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
                using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
                const auto v_scaling = params.qkv_scale_out[2];
                const auto v_quant =
                    *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);

                convert_from_float(v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
            }
            else {
                v = *reinterpret_cast<const V_vec*>(&params.v[v_offset]);
            }
            // Trigger the loads from the V bias buffer.
            // V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
        }

        // Compute the V values with bias.
        v = add(v, v_bias);
        if (write_kv_cache) {
            if (do_ia3) {
                v = mul<V_vec, V_vec, V_vec>(
                    v,
                    *reinterpret_cast<const V_vec*>(
                        &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
            }

            // Store the values with bias back to global memory in the cache for V.
            //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
            *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
        }

        // Initialize the output value with the current timestep.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
        // out = fma(logits_smem[params.timestep], cast_to_float(v), out);
        out = fma(logits_smem[tlength - first_step], cast_to_float(v), out);
#else
        // out = fma(logits_smem[params.timestep], v, out);
        out = fma(logits_smem[tlength - first_step], v, out);
#endif
    }

    // Make sure we can start writing to shared memory.
    __syncthreads();

    // Run the final reduction amongst the different groups computing different partial outputs.
    if (Dh == Dh_MAX || vi < Dh) {
#pragma unroll
        for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {

            // The midpoint in the number of active groups.
            int midpoint = active_groups / 2;

            // The upper part of active threads store to shared memory.
            if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
                convert_from_float(*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
#else
                *reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
#endif
            }
            __syncthreads();

            // The bottom warps update their values.
            if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
                out = add(*reinterpret_cast<const V_vec*>(&out_smem[vo * Dh + vi]), out);
            }
            __syncthreads();
        }
    }
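    // The loop above is a shared-memory tree reduction: in each round the upper half of the
    // V_PER_ITER thread groups spills its partial V_vec accumulator to out_smem and the lower
    // half adds it in, so after log2(V_PER_ITER) rounds group 0 holds the full weighted sum of V.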
    // Output the final values.
    if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
        if (params.int8_mode == 2) {
            using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_acum>::value>::type;
            out                 = mul<V_vec_acum, float>(*params.attention_out_scale, out);
            *reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhi * Dh + vi])) =
                cast_to_int8(out);
        }
        else {
            convert_from_float(*reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]), out);
        }
#else
        // TODO: support int8_mode?
        *reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]) = out;
#endif
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace mmha

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream);
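For context, a hedged sketch of how this declaration is typically driven from the host side. The wrapper name below (run_single_decode_step) and the head sizes are illustrative assumptions and not part of this commit; only mmha_launch_kernel and Masked_multihead_attention_params come from files in this diff.

// Illustrative only: instantiating the launcher for one decoding step.
// run_single_decode_step and the Dh values are assumptions made for this sketch.
template<typename T>
void run_single_decode_step(Masked_multihead_attention_params<T>& params, const cudaStream_t& stream)
{
    // Dh / Dh_MAX are the per-head hidden sizes baked in at compile time; 128 is a common value.
    constexpr int Dh     = 128;
    constexpr int Dh_MAX = 128;
    mmha_launch_kernel<T, Dh, Dh_MAX, Masked_multihead_attention_params<T>>(params, stream);
}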
awq_cuda/attention/decoder_masked_multihead_attention_utils.h
0 → 100644
View file @
1b0af2d3
// Downloaded from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include "cuda_bf16_fallbacks.cuh"
#include <stdint.h>
using namespace fastertransformer;

namespace mmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
struct Float8_ {
    float2 x;
    float2 y;
    float2 z;
    float2 w;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Float4_ {
    float2 x;
    float2 y;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
struct bf16_4_t {
    __nv_bfloat162 x;
    __nv_bfloat162 y;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct bf16_8_t {
    __nv_bfloat162 x;
    __nv_bfloat162 y;
    __nv_bfloat162 z;
    __nv_bfloat162 w;
};
#endif
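As a quick aid (the checks below are an added illustration, not part of the upstream header): these structs regroup packed half/bf16 vectors into float2 lanes so that accumulation can happen in fp32.

// Illustrative only: size relationships between the packed half vectors and their fp32 counterparts.
static_assert(sizeof(uint4) == 8 * sizeof(uint16_t), "a uint4 carries 8 packed half values");
static_assert(sizeof(Float8_) == 8 * sizeof(float), "Float8_ holds the same 8 lanes widened to float");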
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct num_elems;
template<> struct num_elems<float> { static constexpr int value = 1; };
template<> struct num_elems<float2> { static constexpr int value = 2; };
template<> struct num_elems<float4> { static constexpr int value = 4; };
template<> struct num_elems<Float4_> { static constexpr int value = 4; };
template<> struct num_elems<Float8_> { static constexpr int value = 8; };
template<> struct num_elems<uint32_t> { static constexpr int value = 2; };
template<> struct num_elems<uint2> { static constexpr int value = 4; };
template<> struct num_elems<uint4> { static constexpr int value = 8; };
#ifdef ENABLE_BF16
template<> struct num_elems<__nv_bfloat162> { static constexpr int value = 2; };
template<> struct num_elems<bf16_4_t> { static constexpr int value = 4; };
template<> struct num_elems<bf16_8_t> { static constexpr int value = 8; };
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int N>
struct packed_type;
template<typename T> struct packed_type<T, 1> { using type = T; };
template<> struct packed_type<int8_t, 2> { using type = int16_t; };
template<> struct packed_type<int8_t, 4> { using type = int32_t; };
template<> struct packed_type<int8_t, 8> { using type = int64_t; };
template<> struct packed_type<float, 2> { using type = float2; };
template<> struct packed_type<float, 4> { using type = float4; };
template<> struct packed_type<float, 8> { using type = Float8_; };
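A short, self-contained illustration (added here, not part of the upstream header) of how these traits compose: given a vector type, num_elems reports its lane count and packed_type maps a scalar to a vector of the same width, which is exactly how the int8 KV-cache path in the kernel picks Packed_Int8_t and Packed_Float_t.

// Illustrative only: compile-time checks of the trait pair used by the int8 paths.
using V_vec_example     = uint4;                                            // 8 packed half values
using Packed_Int8_check = packed_type<int8_t, num_elems<V_vec_example>::value>::type;
using Packed_F32_check  = packed_type<float, num_elems<V_vec_example>::value>::type;
static_assert(num_elems<V_vec_example>::value == 8, "uint4 carries 8 half lanes");
static_assert(sizeof(Packed_Int8_check) == 8, "8 int8 values pack into an int64_t");
static_assert(sizeof(Packed_F32_check) == sizeof(Float8_), "8 floats pack into a Float8_");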
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
add
(
float
a
,
float
b
)
{
return
a
+
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
add
(
float2
a
,
float2
b
)
{
float2
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
add
(
float4
a
,
float4
b
)
{
float4
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
c
.
z
=
add
(
a
.
z
,
b
.
z
);
c
.
w
=
add
(
a
.
w
,
b
.
w
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
__nv_bfloat16
add
(
__nv_bfloat16
a
,
__nv_bfloat16
b
)
{
return
a
+
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
__nv_bfloat162
add
(
__nv_bfloat162
a
,
__nv_bfloat162
b
)
{
return
bf16hadd2
(
a
,
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_4_t
add
(
bf16_4_t
a
,
bf16_4_t
b
)
{
bf16_4_t
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_8_t
add
(
bf16_8_t
a
,
bf16_8_t
b
)
{
bf16_8_t
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
c
.
z
=
add
(
a
.
z
,
b
.
z
);
c
.
w
=
add
(
a
.
w
,
b
.
w
);
return
c
;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint16_t
add
(
uint16_t
a
,
uint16_t
b
)
{
uint16_t
c
;
asm
volatile
(
"add.f16 %0, %1, %2;
\n
"
:
"=h"
(
c
)
:
"h"
(
a
),
"h"
(
b
));
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
add
(
uint32_t
a
,
uint32_t
b
)
{
uint32_t
c
;
asm
volatile
(
"add.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
c
)
:
"r"
(
a
),
"r"
(
b
));
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint2
add
(
uint2
a
,
uint2
b
)
{
uint2
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint4
add
(
uint4
a
,
uint4
b
)
{
uint4
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
c
.
z
=
add
(
a
.
z
,
b
.
z
);
c
.
w
=
add
(
a
.
w
,
b
.
w
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint16_t
float_to_half
(
float
f
)
{
union
{
uint32_t
u32
;
uint16_t
u16
[
2
];
}
tmp
;
#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better?
float zero = 0.f;
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f));
#else
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;
\n
"
:
"=h"
(
tmp
.
u16
[
0
])
:
"f"
(
f
));
#endif
return
tmp
.
u16
[
0
];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
float2_to_half2
(
float2
f
)
{
union
{
uint32_t
u32
;
uint16_t
u16
[
2
];
}
tmp
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
"cvt.rn.f16x2.f32 %0, %1, %2;
\n
"
:
"=r"
(
tmp
.
u32
)
:
"f"
(
f
.
y
),
"f"
(
f
.
x
));
#else
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;
\n
"
:
"=h"
(
tmp
.
u16
[
0
])
:
"f"
(
f
.
x
));
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;
\n
"
:
"=h"
(
tmp
.
u16
[
1
])
:
"f"
(
f
.
y
));
#endif
return
tmp
.
u32
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
half_to_float
(
uint16_t
h
)
{
float
f
;
asm
volatile
(
"cvt.f32.f16 %0, %1;
\n
"
:
"=f"
(
f
)
:
"h"
(
h
));
return
f
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
half2_to_float2
(
uint32_t
v
)
{
uint16_t
lo
,
hi
;
asm
volatile
(
"mov.b32 {%0, %1}, %2;
\n
"
:
"=h"
(
lo
),
"=h"
(
hi
)
:
"r"
(
v
));
return
make_float2
(
half_to_float
(
lo
),
half_to_float
(
hi
));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
add
(
float
a
,
uint16_t
b
)
{
return
a
+
half_to_float
(
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
float
add
(
float
a
,
__nv_bfloat16
b
)
{
return
a
+
__bfloat162float
(
b
);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
add
(
uint32_t
a
,
float2
fb
)
{
float2
fa
=
half2_to_float2
(
a
);
return
add
(
fa
,
fb
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
add
(
uint2
a
,
Float4_
fb
)
{
Float4_
fc
;
fc
.
x
=
add
(
a
.
x
,
fb
.
x
);
fc
.
y
=
add
(
a
.
y
,
fb
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
add
(
uint4
a
,
Float8_
fb
)
{
Float8_
fc
;
fc
.
x
=
add
(
a
.
x
,
fb
.
x
);
fc
.
y
=
add
(
a
.
y
,
fb
.
y
);
fc
.
z
=
add
(
a
.
z
,
fb
.
z
);
fc
.
w
=
add
(
a
.
w
,
fb
.
w
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
h0_h0
(
uint16_t
a
)
{
uint32_t
b
;
asm
volatile
(
"mov.b32 %0, {%1, %1};"
:
"=r"
(
b
)
:
"h"
(
a
));
return
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
fma
(
float
a
,
float
b
,
float
c
)
{
return
a
*
b
+
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
float2
a
,
float2
b
,
float2
c
)
{
float2
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
float
a
,
float2
b
,
float2
c
)
{
float2
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
fma
(
float4
a
,
float4
b
,
float4
c
)
{
float4
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
.
z
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
.
w
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
fma
(
float
a
,
float4
b
,
float4
c
)
{
float4
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
float
a
,
Float4_
b
,
Float4_
c
)
{
Float4_
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
float
a
,
Float8_
b
,
Float8_
c
)
{
Float8_
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
float2
add
(
__nv_bfloat162
a
,
float2
fb
)
{
float2
fa
=
bf1622float2
(
a
);
return
add
(
fa
,
fb
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
add
(
bf16_4_t
a
,
Float4_
fb
)
{
Float4_
fc
;
fc
.
x
=
add
(
a
.
x
,
fb
.
x
);
fc
.
y
=
add
(
a
.
y
,
fb
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
add
(
bf16_8_t
a
,
Float8_
fb
)
{
Float8_
fc
;
fc
.
x
=
add
(
a
.
x
,
fb
.
x
);
fc
.
y
=
add
(
a
.
y
,
fb
.
y
);
fc
.
z
=
add
(
a
.
z
,
fb
.
z
);
fc
.
w
=
add
(
a
.
w
,
fb
.
w
);
return
fc
;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
fma
(
uint32_t
a
,
uint32_t
b
,
uint32_t
c
)
{
uint32_t
d
;
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
d
)
:
"r"
(
a
),
"r"
(
b
),
"r"
(
c
));
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
fma
(
uint16_t
a
,
uint32_t
b
,
uint32_t
c
)
{
return
fma
(
h0_h0
(
a
),
b
,
c
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint2
fma
(
uint2
a
,
uint2
b
,
uint2
c
)
{
uint2
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint2
fma
(
uint16_t
a
,
uint2
b
,
uint2
c
)
{
uint32_t
s
=
h0_h0
(
a
);
uint2
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint4
fma
(
uint4
a
,
uint4
b
,
uint4
c
)
{
uint4
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
.
z
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
.
w
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint4
fma
(
uint16_t
a
,
uint4
b
,
uint4
c
)
{
uint32_t
s
=
h0_h0
(
a
);
uint4
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
s
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
s
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
fma
(
uint16_t
a
,
uint16_t
b
,
float
fc
)
{
float
fa
=
half_to_float
(
a
);
float
fb
=
half_to_float
(
b
);
return
fa
*
fb
+
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
uint32_t
a
,
uint32_t
b
,
float2
fc
)
{
float2
fa
=
half2_to_float2
(
a
);
float2
fb
=
half2_to_float2
(
b
);
return
fma
(
fa
,
fb
,
fc
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
uint16_t
a
,
uint32_t
b
,
float2
fc
)
{
return
fma
(
h0_h0
(
a
),
b
,
fc
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
uint2
a
,
uint2
b
,
Float4_
fc
)
{
Float4_
fd
;
fd
.
x
=
fma
(
a
.
x
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
a
.
y
,
b
.
y
,
fc
.
y
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
uint16_t
a
,
uint2
b
,
Float4_
fc
)
{
uint32_t
s
=
h0_h0
(
a
);
Float4_
fd
;
fd
.
x
=
fma
(
s
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
s
,
b
.
y
,
fc
.
y
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
uint4
a
,
uint4
b
,
Float8_
fc
)
{
Float8_
fd
;
fd
.
x
=
fma
(
a
.
x
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
a
.
y
,
b
.
y
,
fc
.
y
);
fd
.
z
=
fma
(
a
.
z
,
b
.
z
,
fc
.
z
);
fd
.
w
=
fma
(
a
.
w
,
b
.
w
,
fc
.
w
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
uint16_t
a
,
uint4
b
,
Float8_
fc
)
{
uint32_t
s
=
h0_h0
(
a
);
Float8_
fd
;
fd
.
x
=
fma
(
s
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
s
,
b
.
y
,
fc
.
y
);
fd
.
z
=
fma
(
s
,
b
.
z
,
fc
.
z
);
fd
.
w
=
fma
(
s
,
b
.
w
,
fc
.
w
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
__nv_bfloat162
fma
(
__nv_bfloat162
a
,
__nv_bfloat162
b
,
__nv_bfloat162
c
)
{
return
bf16hfma2
(
a
,
b
,
c
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
__nv_bfloat162
fma
(
__nv_bfloat16
a
,
__nv_bfloat162
b
,
__nv_bfloat162
c
)
{
return
bf16hfma2
(
bf162bf162
(
a
),
b
,
c
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_4_t
fma
(
bf16_4_t
a
,
bf16_4_t
b
,
bf16_4_t
c
)
{
bf16_4_t
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_4_t
fma
(
__nv_bfloat16
a
,
bf16_4_t
b
,
bf16_4_t
c
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
bf16_4_t
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_8_t
fma
(
bf16_8_t
a
,
bf16_8_t
b
,
bf16_8_t
c
)
{
bf16_8_t
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
.
z
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
.
w
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_8_t
fma
(
__nv_bfloat16
a
,
bf16_8_t
b
,
bf16_8_t
c
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
bf16_8_t
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
s
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
s
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
fma
(
__nv_bfloat16
a
,
__nv_bfloat16
b
,
float
fc
)
{
return
__bfloat162float
(
a
)
*
__bfloat162float
(
b
)
+
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
__nv_bfloat162
a
,
__nv_bfloat162
b
,
float2
fc
)
{
float2
fa
=
bf1622float2
(
a
);
float2
fb
=
bf1622float2
(
b
);
return
fma
(
fa
,
fb
,
fc
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
__nv_bfloat16
a
,
__nv_bfloat162
b
,
float2
fc
)
{
return
fma
(
bf162bf162
(
a
),
b
,
fc
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
bf16_4_t
a
,
bf16_4_t
b
,
Float4_
fc
)
{
Float4_
fd
;
fd
.
x
=
fma
(
a
.
x
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
a
.
y
,
b
.
y
,
fc
.
y
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
__nv_bfloat16
a
,
bf16_4_t
b
,
Float4_
fc
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
Float4_
fd
;
fd
.
x
=
fma
(
s
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
s
,
b
.
y
,
fc
.
y
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
bf16_8_t
a
,
bf16_8_t
b
,
Float8_
fc
)
{
Float8_
fd
;
fd
.
x
=
fma
(
a
.
x
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
a
.
y
,
b
.
y
,
fc
.
y
);
fd
.
z
=
fma
(
a
.
z
,
b
.
z
,
fc
.
z
);
fd
.
w
=
fma
(
a
.
w
,
b
.
w
,
fc
.
w
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
__nv_bfloat16
a
,
bf16_8_t
b
,
Float8_
fc
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
Float8_
fd
;
fd
.
x
=
fma
(
s
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
s
,
b
.
y
,
fc
.
y
);
fd
.
z
=
fma
(
s
,
b
.
z
,
fc
.
z
);
fd
.
w
=
fma
(
s
,
b
.
w
,
fc
.
w
);
return
fd
;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Acc
,
typename
A
,
typename
B
>
inline
__device__
Acc
mul
(
A
a
,
B
b
)
{
return
a
*
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float
mul
<
float
,
float
>
(
float
a
,
float
b
)
{
return
a
*
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
float2
a
,
float2
b
)
{
float2
c
;
c
.
x
=
a
.
x
*
b
.
x
;
c
.
y
=
a
.
y
*
b
.
y
;
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
float
a
,
float2
b
)
{
float2
c
;
c
.
x
=
a
*
b
.
x
;
c
.
y
=
a
*
b
.
y
;
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float4
mul
(
float4
a
,
float4
b
)
{
float4
c
;
c
.
x
=
a
.
x
*
b
.
x
;
c
.
y
=
a
.
y
*
b
.
y
;
c
.
z
=
a
.
z
*
b
.
z
;
c
.
w
=
a
.
w
*
b
.
w
;
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float4
mul
(
float
a
,
float4
b
)
{
float4
c
;
c
.
x
=
a
*
b
.
x
;
c
.
y
=
a
*
b
.
y
;
c
.
z
=
a
*
b
.
z
;
c
.
w
=
a
*
b
.
w
;
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float8_
mul
(
float
a
,
Float8_
b
)
{
Float8_
c
;
c
.
x
=
make_float2
(
a
*
b
.
x
.
x
,
a
*
b
.
x
.
y
);
c
.
y
=
make_float2
(
a
*
b
.
y
.
x
,
a
*
b
.
y
.
y
);
c
.
z
=
make_float2
(
a
*
b
.
z
.
x
,
a
*
b
.
z
.
y
);
c
.
w
=
make_float2
(
a
*
b
.
w
.
x
,
a
*
b
.
w
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint16_t
mul
(
uint16_t
a
,
uint16_t
b
)
{
uint16_t
c
;
asm
volatile
(
"mul.f16 %0, %1, %2;
\n
"
:
"=h"
(
c
)
:
"h"
(
a
),
"h"
(
b
));
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint32_t
mul
(
uint32_t
a
,
uint32_t
b
)
{
uint32_t
c
;
asm
volatile
(
"mul.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
c
)
:
"r"
(
a
),
"r"
(
b
));
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint32_t
mul
(
uint16_t
a
,
uint32_t
b
)
{
return
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
h0_h0
(
a
),
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint2
mul
(
uint2
a
,
uint2
b
)
{
uint2
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint2
mul
(
uint16_t
a
,
uint2
b
)
{
uint32_t
s
=
h0_h0
(
a
);
uint2
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint4
mul
(
uint4
a
,
uint4
b
)
{
uint4
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
c
.
z
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
z
,
b
.
z
);
c
.
w
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
w
,
b
.
w
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint4
mul
(
uint16_t
a
,
uint4
b
)
{
uint32_t
s
=
h0_h0
(
a
);
uint4
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
y
);
c
.
z
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
z
);
c
.
w
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
w
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float
mul
(
uint16_t
a
,
uint16_t
b
)
{
float
fa
=
half_to_float
(
a
);
float
fb
=
half_to_float
(
b
);
return
fa
*
fb
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float
mul
(
uint16_t
a
,
float
b
)
{
return
half_to_float
(
a
)
*
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
uint32_t
a
,
uint32_t
b
)
{
float2
fa
=
half2_to_float2
(
a
);
float2
fb
=
half2_to_float2
(
b
);
return
mul
<
float2
,
float2
,
float2
>
(
fa
,
fb
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
uint16_t
a
,
uint32_t
b
)
{
return
mul
<
float2
,
uint32_t
,
uint32_t
>
(
h0_h0
(
a
),
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float4_
mul
(
uint2
a
,
uint2
b
)
{
Float4_
fc
;
fc
.
x
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float4_
mul
(
uint16_t
a
,
uint2
b
)
{
uint32_t
s
=
h0_h0
(
a
);
Float4_
fc
;
fc
.
x
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float8_
mul
(
uint4
a
,
uint4
b
)
{
Float8_
fc
;
fc
.
x
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
fc
.
z
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
z
,
b
.
z
);
fc
.
w
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
w
,
b
.
w
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float8_
mul
(
uint16_t
a
,
uint4
b
)
{
uint32_t
s
=
h0_h0
(
a
);
Float8_
fc
;
fc
.
x
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
y
);
fc
.
z
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
z
);
fc
.
w
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
w
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
template
<
>
inline
__device__
__nv_bfloat16
mul
(
__nv_bfloat16
a
,
__nv_bfloat16
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return
__hmul
(
a
,
b
);
#else
return
bf16hmul
(
a
,
b
);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
__nv_bfloat162
mul
(
__nv_bfloat162
a
,
__nv_bfloat162
b
)
{
return
bf16hmul2
(
a
,
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
__nv_bfloat162
mul
(
__nv_bfloat16
a
,
__nv_bfloat162
b
)
{
return
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
bf162bf162
(
a
),
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
bf16_4_t
mul
(
bf16_4_t
a
,
bf16_4_t
b
)
{
bf16_4_t
c
;
c
.
x
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
y
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
bf16_4_t
mul
(
__nv_bfloat16
a
,
bf16_4_t
b
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
bf16_4_t
c
;
c
.
x
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
x
);
c
.
y
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
bf16_8_t
mul
(
bf16_8_t
a
,
bf16_8_t
b
)
{
bf16_8_t
c
;
c
.
x
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
y
,
b
.
y
);
c
.
z
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
z
,
b
.
z
);
c
.
w
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
w
,
b
.
w
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
bf16_8_t
mul
(
__nv_bfloat16
a
,
bf16_8_t
b
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
bf16_8_t
c
;
c
.
x
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
x
);
c
.
y
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
y
);
c
.
z
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
z
);
c
.
w
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
w
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float
mul
(
__nv_bfloat16
a
,
__nv_bfloat16
b
)
{
float
fa
=
(
float
)
a
;
float
fb
=
(
float
)
b
;
return
fa
*
fb
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float
mul
(
__nv_bfloat16
a
,
float
b
)
{
return
__bfloat162float
(
a
)
*
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
__nv_bfloat162
a
,
__nv_bfloat162
b
)
{
float2
fa
=
bf1622float2
(
a
);
float2
fb
=
bf1622float2
(
b
);
return
mul
<
float2
,
float2
,
float2
>
(
fa
,
fb
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
__nv_bfloat16
a
,
__nv_bfloat162
b
)
{
return
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
bf162bf162
(
a
),
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float4_
mul
(
bf16_4_t
a
,
bf16_4_t
b
)
{
Float4_
fc
;
fc
.
x
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
x
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
y
,
b
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float4_
mul
(
__nv_bfloat16
a
,
bf16_4_t
b
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
Float4_
fc
;
fc
.
x
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float8_
mul
(
bf16_8_t
a
,
bf16_8_t
b
)
{
Float8_
fc
;
fc
.
x
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
x
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
y
,
b
.
y
);
fc
.
z
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
z
,
b
.
z
);
fc
.
w
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
w
,
b
.
w
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float8_
mul
(
__nv_bfloat16
a
,
bf16_8_t
b
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
Float8_
fc
;
fc
.
x
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
y
);
fc
.
z
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
z
);
fc
.
w
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
w
);
return
fc
;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
float
v
)
{
return
v
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
float2
v
)
{
return
v
.
x
+
v
.
y
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
float4
v
)
{
return
v
.
x
+
v
.
y
+
v
.
z
+
v
.
w
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
float
sum
(
__nv_bfloat162
v
)
{
float2
vf
=
bf1622float2
(
v
);
return
vf
.
x
+
vf
.
y
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
bf16_4_t
v
)
{
return
sum
(
v
.
x
)
+
sum
(
v
.
y
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
bf16_8_t
v
)
{
return
sum
(
v
.
x
)
+
sum
(
v
.
y
)
+
sum
(
v
.
z
)
+
sum
(
v
.
w
);
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
uint16_t
v
)
{
return
half_to_float
(
v
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
uint32_t
v
)
{
float2
tmp
=
half2_to_float2
(
v
);
return
tmp
.
x
+
tmp
.
y
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
uint2
v
)
{
uint32_t
c
=
add
(
v
.
x
,
v
.
y
);
return
sum
(
c
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
uint4
v
)
{
#if 1
uint32_t
c
=
add
(
v
.
x
,
v
.
y
);
c
=
add
(
c
,
v
.
z
);
c
=
add
(
c
,
v
.
w
);
#else
uint32_t
c
=
add
(
v
.
x
,
v
.
y
);
uint32_t
d
=
add
(
v
.
z
,
v
.
w
);
c
=
add
(
c
,
d
);
#endif
return
sum
(
c
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
Float4_
v
)
{
return
v
.
x
.
x
+
v
.
x
.
y
+
v
.
y
.
x
+
v
.
y
.
y
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
Float8_
v
)
{
return
v
.
x
.
x
+
v
.
x
.
y
+
v
.
y
.
x
+
v
.
y
.
y
+
v
.
z
.
x
+
v
.
z
.
y
+
v
.
w
.
x
+
v
.
w
.
y
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ float dot(T a, T b)
{
    return sum(mul<T, T, T>(a, b));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename A, typename T>
inline __device__ float dot(T a, T b)
{
    return sum(mul<A, T, T>(a, b));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ void zero(uint16_t& dst)
{
    dst = uint16_t(0);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ void zero(T& dst)
{
    constexpr int WORDS = sizeof(T) / 4;
    union {
        T        raw;
        uint32_t words[WORDS];
    } tmp;
#pragma unroll
    for (int ii = 0; ii < WORDS; ++ii) {
        tmp.words[ii] = 0u;
    }
    dst = tmp.raw;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base)
{
    const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
    return {cos(inv_freq), sin(inv_freq)};
}

inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef)
{
    float2 rot_v;
    rot_v.x = coef.x * v.x - coef.y * v.y;
    rot_v.y = coef.x * v.y + coef.y * v.x;
    return rot_v;
}

inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef)
{
    float2 fv     = half2_to_float2(v);
    float2 rot_fv = rotary_embedding_transform(fv, coef);
    return float2_to_half2(rot_fv);
}

#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef)
{
    float2 fv     = bf1622float2(v);
    float2 rot_fv = rotary_embedding_transform(fv, coef);
    return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
}
#endif
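To make the rotary-embedding helpers above easier to follow, here is a hedged, host-side restatement of the same math (added for illustration; it is not part of the upstream header): each channel pair (x, y) at timestep t is rotated by the angle t / base^(zid / rot_embed_dim).

// Illustrative only: a plain-C++ mirror of rotary_embedding_coefficient/transform.
#include <cmath>

inline void rope_rotate_pair(float& x, float& y, int zid, int rot_embed_dim, float t_step, float base = 10000.0f)
{
    const float angle = t_step / std::pow(base, zid / (float)rot_embed_dim);
    const float c = std::cos(angle), s = std::sin(angle);
    const float xr = c * x - s * y;  // same formula as rotary_embedding_transform
    const float yr = c * y + s * x;
    x = xr;
    y = yr;
}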
inline
__device__
void
apply_rotary_embedding
(
float
&
q
,
int
zid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
return
;
}
inline
__device__
void
apply_rotary_embedding
(
float
&
q
,
float
&
k
,
int
zid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
return
;
}
inline
__device__
void
apply_rotary_embedding
(
float2
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
=
rotary_embedding_transform
(
q
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
float2
&
q
,
float2
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
=
rotary_embedding_transform
(
q
,
coef
);
k
=
rotary_embedding_transform
(
k
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
float4
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
Float4_
&
q_
=
*
reinterpret_cast
<
Float4_
*>
(
&
q
);
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q_
.
x
=
rotary_embedding_transform
(
q_
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q_
.
y
=
rotary_embedding_transform
(
q_
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
float4
&
q
,
float4
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
Float4_
&
q_
=
*
reinterpret_cast
<
Float4_
*>
(
&
q
);
Float4_
&
k_
=
*
reinterpret_cast
<
Float4_
*>
(
&
k
);
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q_
.
x
=
rotary_embedding_transform
(
q_
.
x
,
coef0
);
k_
.
x
=
rotary_embedding_transform
(
k_
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q_
.
y
=
rotary_embedding_transform
(
q_
.
y
,
coef1
);
k_
.
y
=
rotary_embedding_transform
(
k_
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
uint32_t
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
=
rotary_embedding_transform
(
q
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
uint32_t
&
q
,
uint32_t
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
=
rotary_embedding_transform
(
q
,
coef
);
k
=
rotary_embedding_transform
(
k
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
uint2
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
uint2
&
q
,
uint2
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
uint4
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
rot_embed_dim
,
t_step
,
base
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
rot_embed_dim
,
t_step
,
base
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
}
inline
__device__
void
apply_rotary_embedding
(
uint4
&
q
,
uint4
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
rot_embed_dim
,
t_step
,
base
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
k
.
z
=
rotary_embedding_transform
(
k
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
rot_embed_dim
,
t_step
,
base
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
k
.
w
=
rotary_embedding_transform
(
k
.
w
,
coef3
);
}
#ifdef ENABLE_BF16
inline
__device__
void
apply_rotary_embedding
(
__nv_bfloat162
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
=
rotary_embedding_transform
(
q
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
__nv_bfloat162
&
q
,
__nv_bfloat162
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
=
rotary_embedding_transform
(
q
,
coef
);
k
=
rotary_embedding_transform
(
k
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
bf16_4_t
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
bf16_4_t
&
q
,
bf16_4_t
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
bf16_8_t
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
rot_embed_dim
,
t_step
,
base
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
rot_embed_dim
,
t_step
,
base
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
}
inline
__device__
void
apply_rotary_embedding
(
bf16_8_t
&
q
,
bf16_8_t
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
rot_embed_dim
,
t_step
,
base
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
k
.
z
=
rotary_embedding_transform
(
k
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
rot_embed_dim
,
t_step
,
base
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
k
.
w
=
rotary_embedding_transform
(
k
.
w
,
coef3
);
}
#endif // ENABLE_BF16
template
<
typename
Vec_T
,
typename
T
>
__device__
__inline__
void
vec_from_smem_transpose
(
Vec_T
&
vec
,
T
*
smem
,
int
transpose_idx
,
int
smem_pitch
);
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
float
&
vec
,
float
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
return
;
}
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
uint32_t
&
vec
,
uint16_t
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
union
{
uint32_t
u32
;
uint16_t
u16
[
2
];
}
tmp
;
tmp
.
u16
[
0
]
=
smem
[
transpose_idx
];
tmp
.
u16
[
1
]
=
smem
[
smem_pitch
+
transpose_idx
];
vec
=
tmp
.
u32
;
}
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
uint2
&
vec
,
uint16_t
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
union
{
uint32_t
u32
;
uint16_t
u16
[
2
];
}
tmp_1
,
tmp_2
;
tmp_1
.
u32
=
*
reinterpret_cast
<
uint32_t
*>
(
&
smem
[
transpose_idx
]);
tmp_2
.
u32
=
*
reinterpret_cast
<
uint32_t
*>
(
&
smem
[
smem_pitch
+
transpose_idx
]);
union
{
uint2
u32x2
;
uint16_t
u16
[
4
];
}
tmp_3
;
tmp_3
.
u16
[
0
]
=
tmp_1
.
u16
[
0
];
tmp_3
.
u16
[
1
]
=
tmp_2
.
u16
[
0
];
tmp_3
.
u16
[
2
]
=
tmp_1
.
u16
[
1
];
tmp_3
.
u16
[
3
]
=
tmp_2
.
u16
[
1
];
vec
=
tmp_3
.
u32x2
;
}
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
uint4
&
vec
,
uint16_t
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
union
{
uint64_t
u64
;
uint16_t
u16
[
4
];
}
tmp_1
,
tmp_2
;
tmp_1
.
u64
=
*
reinterpret_cast
<
uint64_t
*>
(
&
smem
[
transpose_idx
]);
tmp_2
.
u64
=
*
reinterpret_cast
<
uint64_t
*>
(
&
smem
[
smem_pitch
+
transpose_idx
]);
union
{
uint4
u32x4
;
uint16_t
u16
[
8
];
}
tmp_3
;
tmp_3
.
u16
[
0
]
=
tmp_1
.
u16
[
0
];
tmp_3
.
u16
[
1
]
=
tmp_2
.
u16
[
0
];
tmp_3
.
u16
[
2
]
=
tmp_1
.
u16
[
1
];
tmp_3
.
u16
[
3
]
=
tmp_2
.
u16
[
1
];
tmp_3
.
u16
[
4
]
=
tmp_1
.
u16
[
2
];
tmp_3
.
u16
[
5
]
=
tmp_2
.
u16
[
2
];
tmp_3
.
u16
[
6
]
=
tmp_1
.
u16
[
3
];
tmp_3
.
u16
[
7
]
=
tmp_2
.
u16
[
3
];
vec
=
tmp_3
.
u32x4
;
}
#ifdef ENABLE_BF16
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
bf16_4_t
&
vec
,
__nv_bfloat16
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
union
{
uint32_t
u32
;
__nv_bfloat16
bf16
[
2
];
}
tmp_1
,
tmp_2
;
tmp_1
.
u32
=
*
reinterpret_cast
<
uint32_t
*>
(
&
smem
[
transpose_idx
]);
tmp_2
.
u32
=
*
reinterpret_cast
<
uint32_t
*>
(
&
smem
[
smem_pitch
+
transpose_idx
]);
vec
.
x
=
__nv_bfloat162
{
tmp_1
.
bf16
[
0
],
tmp_2
.
bf16
[
0
]};
vec
.
y
=
__nv_bfloat162
{
tmp_1
.
bf16
[
1
],
tmp_2
.
bf16
[
1
]};
}
template<>
__device__ __inline__ void vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint64_t      u64;
        __nv_bfloat16 bf16[4];
    } tmp_1, tmp_2;

    tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
    tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);

    vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
    vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
    vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]};
    vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]};
}
#endif  // ENABLE_BF16

template<>
__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch)
{
    vec.x = smem[transpose_idx];
    vec.z = smem[transpose_idx + 1];
    vec.y = smem[smem_pitch + transpose_idx];
    vec.w = smem[smem_pitch + transpose_idx + 1];
}

template<>
__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        half     u16[2];
    } tmp;

    tmp.u16[0] = smem[transpose_idx];
    tmp.u16[1] = smem[smem_pitch + transpose_idx];

    vec = tmp.u32;
}

#ifdef ENABLE_BF16
template<>
__device__ __inline__ void vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    vec.x = smem[transpose_idx];
    vec.y = smem[smem_pitch + transpose_idx];
}
#endif

template<>
__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch)
{
    vec.x = smem[transpose_idx];
    vec.y = smem[smem_pitch + transpose_idx];
}

template<typename Vec_T, typename T>
__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);

template<>
__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch)
{
    return;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint64_t u64;
        uint16_t u16[4];
    } tmp_1, tmp_2;

    union {
        uint4    u32x4;
        uint16_t u16[8];
    } tmp_3;
    tmp_3.u32x4  = vec;
    tmp_1.u16[0] = tmp_3.u16[0];
    tmp_2.u16[0] = tmp_3.u16[1];
    tmp_1.u16[1] = tmp_3.u16[2];
    tmp_2.u16[1] = tmp_3.u16[3];
    tmp_1.u16[2] = tmp_3.u16[4];
    tmp_2.u16[2] = tmp_3.u16[5];
    tmp_1.u16[3] = tmp_3.u16[6];
    tmp_2.u16[3] = tmp_3.u16[7];

    *reinterpret_cast<uint64_t*>(&smem[transpose_idx])              = tmp_1.u64;
    *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u64;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        uint16_t u16[2];
    } tmp_1, tmp_2;

    union {
        uint2    u32x2;
        uint16_t u16[4];
    } tmp_3;
    tmp_3.u32x2  = vec;
    tmp_1.u16[0] = tmp_3.u16[0];
    tmp_2.u16[0] = tmp_3.u16[1];
    tmp_1.u16[1] = tmp_3.u16[2];
    tmp_2.u16[1] = tmp_3.u16[3];

    *reinterpret_cast<uint32_t*>(&smem[transpose_idx])              = tmp_1.u32;
    *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u32;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        uint16_t u16[2];
    } tmp;
    tmp.u32 = vec;

    smem[transpose_idx]              = tmp.u16[0];
    smem[smem_pitch + transpose_idx] = tmp.u16[1];
}

template<>
__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch)
{
    smem[transpose_idx]                  = vec.x;
    smem[transpose_idx + 1]              = vec.z;
    smem[smem_pitch + transpose_idx]     = vec.y;
    smem[smem_pitch + transpose_idx + 1] = vec.w;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch)
{
    union {
        uint32_t u32;
        half     u16[2];
    } tmp;

    tmp.u32                          = vec;
    smem[transpose_idx]              = tmp.u16[0];
    smem[smem_pitch + transpose_idx] = tmp.u16[1];
}

#ifdef ENABLE_BF16
template<>
__device__ __inline__ void write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    smem[transpose_idx]              = vec.x;
    smem[smem_pitch + transpose_idx] = vec.y;
}

template<>
__device__ __inline__ void write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    write_smem_transpose(reinterpret_cast<const uint2&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
}

template<>
__device__ __inline__ void write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch)
{
    write_smem_transpose(reinterpret_cast<const uint4&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
}
#endif

template<>
__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch)
{
    smem[transpose_idx]              = vec.x;
    smem[smem_pitch + transpose_idx] = vec.y;
}

}  // namespace mmha
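All of the vec_from_smem_transpose / write_smem_transpose specializations above follow the same access pattern: each output vector pairs up elements taken from two shared-memory rows that are smem_pitch elements apart, and the write path scatters those pairs back. A minimal NumPy sketch of that pattern (an illustration only; element types are simplified to plain floats and the names are hypothetical):

import numpy as np

def vec_from_smem_transpose(smem, transpose_idx, smem_pitch, width=2):
    """Gather `width` columns starting at transpose_idx from two rows that are
    smem_pitch apart, pairing the two rows per column (as float2/float4 do above)."""
    flat = smem.reshape(-1)
    row_a = flat[transpose_idx:transpose_idx + width]
    row_b = flat[smem_pitch + transpose_idx:smem_pitch + transpose_idx + width]
    return np.stack([row_a, row_b], axis=1)  # shape [width, 2]

def write_smem_transpose(vec, smem, transpose_idx, smem_pitch):
    """Scatter the pairs back: column 0 of each pair to the first row,
    column 1 to the row smem_pitch further on."""
    flat = smem.reshape(-1)
    width = vec.shape[0]
    flat[transpose_idx:transpose_idx + width] = vec[:, 0]
    flat[smem_pitch + transpose_idx:smem_pitch + transpose_idx + width] = vec[:, 1]

smem = np.arange(16, dtype=np.float32)
v = vec_from_smem_transpose(smem, transpose_idx=0, smem_pitch=8, width=2)
write_smem_transpose(v, smem, transpose_idx=0, smem_pitch=8)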
awq_cuda/attention/ft_attention.cpp  0 → 100644

// Adapted from NVIDIA/FasterTransformer and FlashAttention

#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAGuard.h>

#include "ft_attention.h"
#include "decoder_masked_multihead_attention.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...)                    \
    if (TYPE == at::ScalarType::Half) {                                      \
        using scalar_t = at::Half;                                           \
        __VA_ARGS__();                                                       \
    } else if (TYPE == at::ScalarType::BFloat16) {                           \
        using scalar_t = at::BFloat16;                                       \
        __VA_ARGS__();                                                       \
    } else if (TYPE == at::ScalarType::Float) {                              \
        using scalar_t = float;                                              \
        __VA_ARGS__();                                                       \
    } else {                                                                 \
        AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
    }

template<typename T>
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params,
                                const cudaStream_t& stream);

template<typename T>
void cross_multihead_attention(const Masked_multihead_attention_params<T>& params,
                               const cudaStream_t& stream);

template<typename T>
struct SATypeConverter {
    using Type = T;
};

template<>
struct SATypeConverter<at::Half> {
    using Type = uint16_t;
};

template<>
struct SATypeConverter<at::BFloat16> {
    using Type = __nv_bfloat16;
};

template<typename T>
void set_params(Masked_multihead_attention_params<T>& params,
                const size_t batch_size,
                const size_t nheads,
                const size_t nheads_kv,
                const size_t memory_max_seqlen,
                const size_t headdim,
                const int timestep,
                const int rotary_embedding_dim,
                const float rotary_base,
                const bool neox_rotary_style,
                const int qkv_batch_stride,
                T* q_ptr,
                T* k_ptr,
                T* v_ptr,
                T* k_cache_ptr,
                T* v_cache_ptr,
                int* length_per_sample,
                float* alibi_slopes_ptr,
                T* out_ptr) {
    // Reset the parameters
    memset(&params, 0, sizeof(params));
    params.q = q_ptr;
    params.k = k_ptr;
    params.v = v_ptr;
    params.q_bias = nullptr;
    params.k_bias = nullptr;
    params.v_bias = nullptr;
    params.k_cache = k_cache_ptr;
    params.v_cache = v_cache_ptr;
    params.linear_bias_slopes = alibi_slopes_ptr;
    params.out = out_ptr;
    params.cache_indir = nullptr;
    params.stride = qkv_batch_stride;
    params.batch_size = batch_size;
    params.beam_width = 1;
    params.memory_max_len = memory_max_seqlen;
    params.num_heads = nheads;
    params.num_kv_heads = nheads_kv;
    params.hidden_size_per_head = headdim;
    params.rotary_embedding_dim = rotary_embedding_dim;
    params.rotary_base = rotary_base;
    params.neox_rotary_style = neox_rotary_style;
    params.timestep = timestep;
    params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
    params.total_padding_tokens = nullptr;
    params.masked_tokens = nullptr;
    params.prefix_prompt_lengths = nullptr;
    params.max_prefix_prompt_length = 0;
    params.relative_attention_bias = nullptr;
    params.relative_attention_bias_stride = 0;
    params.cross_attention_out = nullptr;
    params.max_decoder_seq_len = 0;
    params.is_return_cross_attentions = false;
    params.finished = nullptr;
    params.memory_length_per_sample = nullptr;
    params.length_per_sample = length_per_sample;
}

torch::Tensor single_query_attention(const torch::Tensor q,
                                     const torch::Tensor k,
                                     const torch::Tensor v,
                                     torch::Tensor k_cache,
                                     torch::Tensor v_cache,
                                     c10::optional<const torch::Tensor> length_per_sample_,
                                     c10::optional<const torch::Tensor> alibi_slopes_,
                                     const int timestep,
                                     const int rotary_embedding_dim,
                                     const float rotary_base,
                                     // neox_rotary_style = not interleaved
                                     const bool neox_rotary_style) {
    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
    int batch_size = v_cache.size(0);
    int nheads = q.size(1);
    int nheads_kv = v_cache.size(1);
    int memory_max_seqlen = v_cache.size(2);
    int headdim = v_cache.size(3);
    CHECK_SHAPE(q, batch_size, nheads, headdim);
    CHECK_SHAPE(k, batch_size, nheads_kv, headdim);
    CHECK_SHAPE(v, batch_size, nheads_kv, headdim);
    CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim);
    // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
    int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
    CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize);
    TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
    TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
    TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
    // TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
    CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);

    if (length_per_sample_.has_value()) {
        auto length_per_sample = length_per_sample_.value();
        CHECK_DEVICE(length_per_sample);
        CHECK_SHAPE(length_per_sample, batch_size);
        CHECK_CONTIGUOUS(length_per_sample);
        TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
    }

    if (alibi_slopes_.has_value()) {
        auto alibi_slopes = alibi_slopes_.value();
        CHECK_DEVICE(alibi_slopes);
        CHECK_SHAPE(alibi_slopes, nheads);
        CHECK_CONTIGUOUS(alibi_slopes);
        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32);
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    torch::Tensor out = torch::empty_like(q);

    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
        using DataType = typename SATypeConverter<scalar_t>::Type;
        Masked_multihead_attention_params<DataType> params;
        set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim,
                   timestep, rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
                   reinterpret_cast<DataType*>(q.data_ptr()),
                   reinterpret_cast<DataType*>(k.data_ptr()),
                   reinterpret_cast<DataType*>(v.data_ptr()),
                   reinterpret_cast<DataType*>(k_cache.data_ptr()),
                   reinterpret_cast<DataType*>(v_cache.data_ptr()),
                   length_per_sample_.has_value() ? length_per_sample_.value().data_ptr<int>() : nullptr,
                   alibi_slopes_.has_value() ? alibi_slopes_.value().data_ptr<float>() : nullptr,
                   reinterpret_cast<DataType*>(out.data_ptr()));
        auto stream = at::cuda::getCurrentCUDAStream();
        masked_multihead_attention(params, stream);
    });
    return out;
}
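The k_cache check above encodes FasterTransformer's packed cache layout: keys are stored as [B, H, Dh/x, L, x] so that x consecutive key elements of a head can be loaded with one vectorized access. A small Python sketch of the shape arithmetic, with hypothetical sizes, just to make the CHECK_SHAPE calls concrete:

import torch

batch_size, nheads_kv, memory_max_seqlen, headdim = 2, 32, 2048, 128

# packsize x = 8 for fp16/bf16, x = 4 for fp32 (see the check above)
packsize = 8
k_cache = torch.empty(batch_size, nheads_kv, headdim // packsize, memory_max_seqlen, packsize,
                      dtype=torch.float16)
v_cache = torch.empty(batch_size, nheads_kv, memory_max_seqlen, headdim, dtype=torch.float16)

assert k_cache.shape == (2, 32, 16, 2048, 8)  # CHECK_SHAPE(k_cache, B, H, Dh/x, L, x)
assert v_cache.shape == (2, 32, 2048, 128)    # CHECK_SHAPE(v_cache, B, H, L, Dh)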
awq_cuda/attention/ft_attention.h  0 → 100644

#pragma once

#include <torch/extension.h>

torch::Tensor single_query_attention(const torch::Tensor q,
                                     const torch::Tensor k,
                                     const torch::Tensor v,
                                     torch::Tensor k_cache,
                                     torch::Tensor v_cache,
                                     c10::optional<const torch::Tensor> length_per_sample_,
                                     c10::optional<const torch::Tensor> alibi_slopes_,
                                     const int timestep,
                                     const int rotary_embedding_dim = 0,
                                     const float rotary_base = 10000.0f,
                                     const bool neox_rotary_style = true);
awq_cuda/position_embedding/pos_encoding.h

 #pragma once
 #include <torch/extension.h>
 
-void rotary_embedding(
+void rotary_embedding_neox(
   torch::Tensor& positions,
   torch::Tensor& query,
   torch::Tensor& key,
   int head_size,
-  torch::Tensor& cos_sin_cache,
-  bool is_neox);
+  torch::Tensor& cos_sin_cache);
awq_cuda/position_embedding/pos_encoding_kernels.cu

@@ -9,56 +9,15 @@ https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
 #include <ATen/cuda/CUDAContext.h>
 #include "pos_encoding.h"
 
-#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)              \
-  AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)      \
-  AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)       \
-  AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
-
-#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...)       \
-  AT_DISPATCH_SWITCH(                                       \
-    TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
-
-template<typename scalar_t, bool IS_NEOX>
-inline __device__ void apply_rotary_embedding(
-  scalar_t* __restrict__ arr,
-  const scalar_t* __restrict__ cos_ptr,
-  const scalar_t* __restrict__ sin_ptr,
-  int rot_offset,
-  int embed_dim)
-{
-  int x_index, y_index;
-  scalar_t cos, sin;
-  if (IS_NEOX) {
-    // GPT-NeoX style rotary embedding.
-    x_index = rot_offset;
-    y_index = embed_dim + rot_offset;
-    cos = __ldg(cos_ptr + x_index);
-    sin = __ldg(sin_ptr + x_index);
-  } else {
-    // GPT-J style rotary embedding.
-    x_index = 2 * rot_offset;
-    y_index = 2 * rot_offset + 1;
-    cos = __ldg(cos_ptr + x_index / 2);
-    sin = __ldg(sin_ptr + x_index / 2);
-  }
-
-  const scalar_t x = arr[x_index];
-  const scalar_t y = arr[y_index];
-  arr[x_index] = x * cos - y * sin;
-  arr[y_index] = y * cos + x * sin;
-}
-
-template<typename scalar_t, bool IS_NEOX>
-__global__ void rotary_embedding_kernel(
+template<typename scalar_t>
+__global__ void rotary_embedding_neox_kernel(
   const int64_t* __restrict__ positions,       // [num_tokens]
   scalar_t* __restrict__ query,                // [num_tokens, num_heads, head_size]
-  scalar_t* __restrict__ key,                  // [num_tokens, num_kv_heads, head_size]
+  scalar_t* __restrict__ key,                  // [num_tokens, num_heads, head_size]
   const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
   const int rot_dim,
-  const int query_stride,
-  const int key_stride,
+  const int stride,
   const int num_heads,
-  const int num_kv_heads,
   const int head_size) {
   // Each thread block is responsible for one token.
   const int token_idx = blockIdx.x;
@@ -66,72 +25,64 @@ __global__ void rotary_embedding_kernel(
   const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;
   const int embed_dim = rot_dim / 2;
-  const scalar_t* cos_ptr = cache_ptr;
-  const scalar_t* sin_ptr = cache_ptr + embed_dim;
-
-  const int nq = num_heads * embed_dim;
-  for (int i = threadIdx.x; i < nq; i += blockDim.x) {
+  const int n = num_heads * embed_dim;
+  for (int i = threadIdx.x; i < n; i += blockDim.x) {
     const int head_idx = i / embed_dim;
-    const int token_head = token_idx * query_stride + head_idx * head_size;
+    const int token_head = token_idx * stride + head_idx * head_size;
     const int rot_offset = i % embed_dim;
-    apply_rotary_embedding<scalar_t, IS_NEOX>(query + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
-  }
-
-  const int nk = num_kv_heads * embed_dim;
-  for (int i = threadIdx.x; i < nk; i += blockDim.x) {
-    const int head_idx = i / embed_dim;
-    const int token_head = token_idx * key_stride + head_idx * head_size;
-    const int rot_offset = i % embed_dim;
-    apply_rotary_embedding<scalar_t, IS_NEOX>(key + token_head, cos_ptr, sin_ptr, rot_offset, embed_dim);
+    const int x_index = rot_offset;
+    const int y_index = embed_dim + rot_offset;
+
+    const int out_x = token_idx * stride + head_idx * head_size + x_index;
+    const int out_y = token_idx * stride + head_idx * head_size + y_index;
+
+    const scalar_t cos = __ldg(cache_ptr + x_index);
+    const scalar_t sin = __ldg(cache_ptr + y_index);
+
+    const scalar_t q_x = query[token_head + x_index];
+    const scalar_t q_y = query[token_head + y_index];
+    query[out_x] = q_x * cos - q_y * sin;
+    query[out_y] = q_y * cos + q_x * sin;
+
+    const scalar_t k_x = key[token_head + x_index];
+    const scalar_t k_y = key[token_head + y_index];
+    key[out_x] = k_x * cos - k_y * sin;
+    key[out_y] = k_y * cos + k_x * sin;
   }
 }
 
-void rotary_embedding(
-  torch::Tensor& positions,      // [num_tokens]
-  torch::Tensor& query,          // [num_tokens, num_heads * head_size]
-  torch::Tensor& key,            // [num_tokens, num_kv_heads * head_size]
+void rotary_embedding_neox(
+  torch::Tensor& positions,      // [b, num_tokens]
+  torch::Tensor& query,          // [b, num_tokens, 1, num_heads, head_size]
+  torch::Tensor& key,            // [b, num_tokens, 1, num_heads, head_size]
   int head_size,
-  torch::Tensor& cos_sin_cache,  // [max_position, rot_dim]
-  bool is_neox) {
-  int num_tokens = query.size(0);
+  torch::Tensor& cos_sin_cache)  // [max_position, rot_dim]
+{
+  int num_tokens = query.size(0) * query.size(1);
   int rot_dim = cos_sin_cache.size(1);
-  int num_heads = query.size(1) / head_size;
-  int num_kv_heads = key.size(1) / head_size;
-  int query_stride = query.stride(0);
-  int key_stride = key.stride(0);
+  int num_heads = query.size(-2);
+  int stride = num_heads * head_size;
+  // TORCH_CHECK(stride == key.stride(0));
   dim3 grid(num_tokens);
   dim3 block(std::min(num_heads * rot_dim / 2, 512));
   const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
-  VLLM_DISPATCH_FLOATING_TYPES(
+  AT_DISPATCH_FLOATING_TYPES_AND2(
+    at::ScalarType::Half,
+    at::ScalarType::BFloat16,
     query.scalar_type(),
-    "rotary_embedding",
+    "rotary_embedding_neox",
     [&] {
-      if (is_neox) {
-        rotary_embedding_kernel<scalar_t, true><<<grid, block, 0, stream>>>(
-          positions.data_ptr<int64_t>(),
-          query.data_ptr<scalar_t>(),
-          key.data_ptr<scalar_t>(),
-          cos_sin_cache.data_ptr<scalar_t>(),
-          rot_dim, query_stride, key_stride,
-          num_heads, num_kv_heads, head_size);
-      } else {
-        rotary_embedding_kernel<scalar_t, false><<<grid, block, 0, stream>>>(
-          positions.data_ptr<int64_t>(),
-          query.data_ptr<scalar_t>(),
-          key.data_ptr<scalar_t>(),
-          cos_sin_cache.data_ptr<scalar_t>(),
-          rot_dim, query_stride, key_stride,
-          num_heads, num_kv_heads, head_size);
-      }
+      rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
+        positions.data_ptr<int64_t>(),
+        query.data_ptr<scalar_t>(),
+        key.data_ptr<scalar_t>(),
+        cos_sin_cache.data_ptr<scalar_t>(),
+        rot_dim,
+        stride,
+        num_heads,
+        head_size);
     });
 }
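For reference, here is a small PyTorch sketch of what rotary_embedding_neox_kernel computes per token: for each head, the channel pair (rot_offset, embed_dim + rot_offset) within the first rot_dim channels is rotated by the cached cos/sin for that position. This is an illustration of the indexing only, not part of the extension, and it assumes contiguous [b, num_tokens, 1, num_heads, head_size] inputs as in the comments above.

import torch

def rotary_embedding_neox_ref(positions, query, key, head_size, cos_sin_cache):
    """NeoX-style rotation mirroring the kernel above (in place on query/key views)."""
    rot_dim = cos_sin_cache.shape[1]
    embed_dim = rot_dim // 2

    q = query.reshape(-1, query.shape[-2], head_size)  # [num_tokens, num_heads, head_size]
    k = key.reshape(-1, key.shape[-2], head_size)
    pos = positions.reshape(-1)

    cos = cos_sin_cache[pos, :embed_dim]               # [num_tokens, embed_dim]
    sin = cos_sin_cache[pos, embed_dim:rot_dim]

    for t in (q, k):
        x = t[..., :embed_dim].clone()                 # x_index channels
        y = t[..., embed_dim:rot_dim].clone()          # y_index channels
        t[..., :embed_dim]        = x * cos[:, None, :] - y * sin[:, None, :]
        t[..., embed_dim:rot_dim] = y * cos[:, None, :] + x * sin[:, None, :]
    return query, key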
awq_cuda/pybind.cpp

 #include <pybind11/pybind11.h>
 #include <torch/extension.h>
+#include "attention/ft_attention.h"
 #include "layernorm/layernorm.h"
 #include "quantization/gemm_cuda.h"
+#include "quantization/gemv_cuda.h"
 #include "position_embedding/pos_encoding.h"
 
 PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
 {
     m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
     m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
-    m.def("rotary_embedding", &rotary_embedding, "Apply rotary embedding to query and key");
-}
+    m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
+    m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
+    m.def("single_query_attention", &single_query_attention, "Attention with a single query",
+          py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
+          py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"),
+          py::arg("rotary_embedding_dim")=0, py::arg("rotary_base")=10000.0f,
+          py::arg("neox_rotary_style")=true);
+}
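Once the extension is built, the new bindings are reachable from Python roughly as sketched below. The module name awq_inference_engine and the tensor sizes are assumptions for illustration; the shapes follow the CHECK_SHAPE calls in ft_attention.cpp.

import torch
import awq_inference_engine as ie  # extension module name assumed; adjust to your build

B, H, Hkv, D, L = 1, 32, 32, 128, 2048   # hypothetical sizes
dev, dt = "cuda", torch.float16

q = torch.randn(B, H, D, device=dev, dtype=dt)
k = torch.randn(B, Hkv, D, device=dev, dtype=dt)
v = torch.randn(B, Hkv, D, device=dev, dtype=dt)
k_cache = torch.zeros(B, Hkv, D // 8, L, 8, device=dev, dtype=dt)  # packed [B, H, Dh/x, L, x]
v_cache = torch.zeros(B, Hkv, L, D, device=dev, dtype=dt)
lengths = torch.zeros(B, dtype=torch.int32, device=dev)

# Call-shape sketch for one fused decode step; values are placeholders, and
# rotary_embedding_dim=0 leaves rotary to be applied outside the kernel.
out = ie.single_query_attention(q, k, v, k_cache, v_cache, lengths, None,
                                0,                       # timestep = tokens already cached
                                rotary_embedding_dim=0,
                                rotary_base=10000.0,
                                neox_rotary_style=True)  # [B, H, D]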
awq_cuda/quantization/gemm_cuda_gen.cu

@@ -10,6 +10,7 @@
 */
 
 #include <torch/extension.h>
+#include <ATen/cuda/CUDAContext.h>
 #include "gemm_cuda.h"
 #include "dequantize.cuh"
 #include <cuda_fp16.h>

@@ -439,6 +440,7 @@ torch::Tensor gemm_forward_cuda(
     auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
     auto zeros = reinterpret_cast<int*>(_zeros.data_ptr<int>());
     int group_size = num_in_channels / _scaling_factors.size(0);
+    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
 
     if (num_out_channels % 64 != 0)
         throw std::invalid_argument("OC is not multiple of cta_N = 64");

@@ -456,7 +458,7 @@ torch::Tensor gemm_forward_cuda(
     // threadIdx.x: 32
     // threadIdx.y: i_factors[2] * j_factors[2]
     dim3 threads_per_block(32, 2);
-    gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block>>>(
+    gemm_forward_4bit_cuda_m16n128k32<<<num_blocks, threads_per_block, 0, stream>>>(
       group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
       num_in_feats, num_in_channels, num_out_channels, out_feats);
   }
   else if (num_out_channels % 64 == 0)

@@ -467,7 +469,7 @@ torch::Tensor gemm_forward_cuda(
     // threadIdx.x: 32
     // threadIdx.y: i_factors[2] * j_factors[2]
     dim3 threads_per_block(32, 2);
-    gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block>>>(
+    gemm_forward_4bit_cuda_m16n64k32<<<num_blocks, threads_per_block, 0, stream>>>(
       group_size, split_k_iters, in_feats, kernel, scaling_factors, zeros,
       num_in_feats, num_in_channels, num_out_channels, out_feats);
   }
   return _out_feats.sum(0);
awq_cuda/quantization/gemv_cuda.cu  0 → 100644

// Inspired by https://github.com/ankan-ban/llama_cu_awq
/*

@article{lin2023awq,
  title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
  author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
  journal={arXiv},
  year={2023}
}
*/

#include <cuda_fp16.h>
#include <stdio.h>
#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include "gemv_cuda.h"
#define VECTORIZE_FACTOR 8
#define Q_VECTORIZE_FACTOR 8
#define PACK_FACTOR 8
#define WARP_SIZE 32


// Reduce sum within the warp using the tree reduction algorithm.
__device__ __forceinline__ float warp_reduce_sum(float sum) {
  #pragma unroll
  for(int i = 4; i >= 0; i--){
    sum += __shfl_down_sync(0xffffffff, sum, 1 << i);
  }
  /*
  // Equivalent to the following tree reduction implementation:
  sum += __shfl_down_sync(0xffffffff, sum, 16);
  sum += __shfl_down_sync(0xffffffff, sum, 8);
  sum += __shfl_down_sync(0xffffffff, sum, 4);
  sum += __shfl_down_sync(0xffffffff, sum, 2);
  sum += __shfl_down_sync(0xffffffff, sum, 1);
  */
  return sum;
}

__device__ __forceinline__ int make_divisible(int c, int divisor){
  return (c + divisor - 1) / divisor;
}

/*
Computes GEMV (group_size = 64).

Args:
  inputs: vector of shape [batch_size, IC];
  weight: matrix of shape [OC, IC / 8];
  output: vector of shape [OC];
  zeros: matrix of shape [OC, IC / group_size / 8];
  scaling_factors: matrix of shape [OC, IC / group_size];

Notes:
  One cannot infer group_size from the shape of scaling factors.
  the second dimension is rounded up to a multiple of PACK_FACTOR.
*/
__global__ void gemv_kernel_g64(
  const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs,
  const int IC, const int OC){
    const int group_size = 64;
    float psum = 0;
    const int batch_idx = blockIdx.z;
    const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
    const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
    half* outputs = _outputs + batch_idx * OC;
    // This is essentially zeros_w.
    const int num_groups_packed = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2;
    const int weight_w = IC / PACK_FACTOR;
    // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
    const int zeros_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2;
    // consistent with input shape
    const int sf_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2 * PACK_FACTOR;
    // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w, sf_w);
    // tile size: 4 OC x 1024 IC per iter
    for(int packed_group_idx = 0; packed_group_idx < num_groups_packed / 2; packed_group_idx++){
      // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
      uint64_t packed_zeros = *reinterpret_cast<const uint64_t*>(zeros + oc_idx * zeros_w + packed_group_idx * 2);
      uint32_t packed_weights[4];
      // use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
      *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
      // load scaling factors
      // g64: two threads -> 64 numbers -> 1 group; 1 warp = 16 groups.
      float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]);
      float current_zeros = (float)((packed_zeros >> (threadIdx.x / 2 * 4)) & 0xF);
      int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
      const float4* inputs_ptr = inputs + inputs_ptr_delta;
      // multiply 32 weights with 32 inputs
      #pragma unroll
      for (int ic_0 = 0; ic_0 < 4; ic_0++){
        // iterate over different uint32_t packed_weights in this loop
        uint32_t current_packed_weight = packed_weights[ic_0];
        half packed_inputs[PACK_FACTOR];
        // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
        if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
          *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
          #pragma unroll
          for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
            // iterate over 8 numbers packed within each uint32_t number
            float current_single_weight_fp = (float)(current_packed_weight & 0xF);
            float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
            //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
            psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
            current_packed_weight = current_packed_weight >> 4;
          }
        }
      }
    }
    psum = warp_reduce_sum(psum);
    if (threadIdx.x == 0) {
      outputs[oc_idx] = __float2half(psum);
    }
}


/*
Computes GEMV (group_size = 128).

Args:
  inputs: vector of shape [batch_size, IC];
  weight: matrix of shape [OC, IC / 8];
  output: vector of shape [OC];
  zeros: matrix of shape [OC, IC / group_size / 8];
  scaling_factors: matrix of shape [OC, IC / group_size];

Notes:
  One cannot infer group_size from the shape of scaling factors.
  the second dimension is rounded up to a multiple of PACK_FACTOR.
*/
__global__ void gemv_kernel_g128(
  const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs,
  const int IC, const int OC){
    const int group_size = 128;
    float psum = 0;
    const int batch_idx = blockIdx.z;
    const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
    const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
    half* outputs = _outputs + batch_idx * OC;
    const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR);
    const int weight_w = IC / PACK_FACTOR;
    // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
    const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR);
    // consistent with input shape
    const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR;
    //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w);
    // tile size: 4 OC x 1024 IC per iter
    for(int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++){
      // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
      uint32_t packed_zeros = *(zeros + oc_idx * zeros_w + packed_group_idx);
      uint32_t packed_weights[4];
      // use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
      *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
      // load scaling factors
      // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups.
      float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]);
      float current_zeros = (float)((packed_zeros >> (threadIdx.x / 4 * 4)) & 0xF);
      int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
      const float4* inputs_ptr = inputs + inputs_ptr_delta;
      // multiply 32 weights with 32 inputs
      #pragma unroll
      for (int ic_0 = 0; ic_0 < 4; ic_0++){
        // iterate over different uint32_t packed_weights in this loop
        uint32_t current_packed_weight = packed_weights[ic_0];
        half packed_inputs[PACK_FACTOR];
        // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
        if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
          *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
          #pragma unroll
          for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++){
            // iterate over 8 numbers packed within each uint32_t number
            float current_single_weight_fp = (float)(current_packed_weight & 0xF);
            float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
            //if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
            psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
            current_packed_weight = current_packed_weight >> 4;
          }
        }
      }
    }
    psum = warp_reduce_sum(psum);
    if (threadIdx.x == 0) {
      outputs[oc_idx] = __float2half(psum);
    }
}


/*
Computes GEMV (PyTorch interface).

Args:
  _in_feats: tensor of shape [B, IC];
  _kernel: int tensor of shape [OC, IC // 8];
  _zeros: int tensor of shape [OC, IC // G // 8];
  _scaling_factors: tensor of shape [OC, IC // G];
  blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC;
  blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC;

Returns:
  out_feats: tensor of shape [B, OC];
*/
torch::Tensor gemv_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int group_size)
{
    int num_in_feats = _in_feats.size(0);
    int num_in_channels = _in_feats.size(1);
    // int kernel_volume = _out_in_map.size(1);
    auto in_feats = reinterpret_cast<float4*>(_in_feats.data_ptr<at::Half>());
    auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr<int>());
    auto zeros = reinterpret_cast<uint32_t*>(_zeros.data_ptr<int>());
    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    // auto out_in_map = _out_in_map.data_ptr<int>();
    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    // kernel is [OC, IC]
    at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options);
    int num_out_feats = _out_feats.size(-2);
    int num_out_channels = _out_feats.size(-1);
    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());
    int blockDim_z = num_out_feats;
    dim3 num_blocks(1, num_out_channels / 4, num_out_feats);
    dim3 num_threads(32, 4);
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    if (group_size == 64)
    {
      gemv_kernel_g64<<<num_blocks, num_threads, 0, stream>>>(
        // pointers
        in_feats, kernel, zeros, scaling_factors, out_feats,
        // constants
        num_in_channels, num_out_channels
      );
    }
    else if (group_size == 128)
    {
      gemv_kernel_g128<<<num_blocks, num_threads, 0, stream>>>(
        // pointers
        in_feats, kernel, zeros, scaling_factors, out_feats,
        // constants
        num_in_channels, num_out_channels
      );
    }
    return _out_feats;
}
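Functionally, both kernels compute out[b, oc] = sum over ic of scale * (q - zero) * x[b, ic], where eight int4 values are unpacked from each uint32 lowest nibble first (matching the `& 0xF` then `>>= 4` loop above). A NumPy correctness model of that dequantized matrix-vector product, written as a sketch under those packing assumptions and with no claim about performance:

import numpy as np

PACK_FACTOR = 8  # eight 4-bit values per uint32

def unpack_int4(packed):
    """Unpack uint32 -> 8 int4 values, lowest nibble first."""
    shifts = np.arange(PACK_FACTOR, dtype=np.uint32) * 4
    return (packed[..., None] >> shifts) & 0xF

def awq_gemv_ref(x, qweight, qzeros, scales, group_size):
    """x: [B, IC] float; qweight: [OC, IC//8] uint32; qzeros: [OC, IC//group_size//8] uint32;
    scales: [OC, >= IC//group_size] float. Returns [B, OC] float32."""
    OC = qweight.shape[0]
    IC = x.shape[1]
    n_groups = IC // group_size
    w_int = unpack_int4(qweight).reshape(OC, IC).astype(np.float32)     # [OC, IC]
    z_int = unpack_int4(qzeros).reshape(OC, -1)[:, :n_groups]           # [OC, n_groups]
    z = np.repeat(z_int.astype(np.float32), group_size, axis=1)         # broadcast per group
    s = np.repeat(scales[:, :n_groups].astype(np.float32), group_size, axis=1)
    w = s * (w_int - z)                                                 # dequantized weights [OC, IC]
    return x.astype(np.float32) @ w.T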
awq_cuda/quantization/gemv_cuda.h  0 → 100644

#pragma once
#include <torch/extension.h>

torch::Tensor gemv_forward_cuda(
  torch::Tensor _in_feats,
  torch::Tensor _kernel,
  torch::Tensor _scaling_factors,
  torch::Tensor _zeros,
  int group_size);
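Argument shapes for this entry point follow the docstring in gemv_cuda.cu. A minimal call sketch (the extension module name and the sizes are assumptions for illustration):

import torch
import awq_inference_engine as ie  # assumed extension name from setup.py

B, IC, OC, G = 1, 4096, 11008, 128
x      = torch.randn(B, IC, dtype=torch.float16, device="cuda")
qw     = torch.zeros(OC, IC // 8, dtype=torch.int32, device="cuda")       # packed int4 weights
scales = torch.ones(OC, IC // G, dtype=torch.float16, device="cuda")
qzeros = torch.zeros(OC, IC // G // 8, dtype=torch.int32, device="cuda")  # packed int4 zero points

y = ie.gemv_forward_cuda(x, qw, scales, qzeros, G)  # fp16, shape [B, OC]
assert y.shape == (B, OC)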
examples/basic_quant.py

@@ -3,7 +3,7 @@ from transformers import AutoTokenizer
 
 model_path = 'lmsys/vicuna-7b-v1.5'
 quant_path = 'vicuna-7b-v1.5-awq'
-quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4 }
+quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM" }
 
 # Load model
 model = AutoAWQForCausalLM.from_pretrained(model_path)
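The new "version" key selects which packed kernel the quantized layers target; "GEMM" keeps the existing kernel, and a "GEMV" value is presumably what exercises the kernel added in this merge (the exact accepted strings are an assumption read off this diff, not documented here):

# Same config as above, but targeting the new GEMV kernel (value assumed from this PR).
quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMV" }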
examples/benchmark.py  0 → 100644

import time
import torch
import argparse
import numpy as np
import pandas as pd
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from torch.cuda import OutOfMemoryError

def warmup(model):
    warm_up = torch.randn((4096, 4096)).to(next(model.parameters()).device)
    torch.mm(warm_up, warm_up)

def generate(model, input_ids, n_generate):
    context_time = 0
    generate_time = []

    with torch.inference_mode():
        for i in range(n_generate):
            torch.cuda.synchronize()
            start = time.time()

            if i == 0:
                # prefill context
                inputs = torch.as_tensor(input_ids, device=next(model.parameters()).device)
            else:
                # decode tokens
                inputs = torch.as_tensor(token, device=next(model.parameters()).device)

            out = model(inputs, use_cache=True)
            torch.cuda.synchronize()
            token = out[0][:, -1].max(1)[1].unsqueeze(1)

            if i == 0:
                context_time += time.time() - start
            else:
                generate_time.append(time.time() - start)

    return context_time, generate_time

def run_round(model_path, quant_file, n_generate, input_ids, batch_size):
    print(f" -- Loading model...")
    model = AutoAWQForCausalLM.from_quantized(
        model_path, quant_file, fuse_layers=True,
        max_new_tokens=n_generate, batch_size=batch_size
    )

    print(f" -- Warming up...")
    warmup(model)

    print(f" -- Generating {n_generate} tokens, {input_ids.shape[1]} in context...")

    try:
        context_time, generate_time = generate(model, input_ids, n_generate)
        successful_generate = True
    except RuntimeError as ex:
        if 'cuda out of memory' in str(ex).lower():
            successful_generate = False
        else:
            raise RuntimeError(ex)

    device = next(model.parameters()).device
    memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 3)
    memory_pct = memory_used / (torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)) * 100

    if successful_generate:
        # number of tokens in context / time for processing context * batch size
        prefill_tokens_per_second = input_ids.shape[1] / context_time * batch_size
        # 1 second / median time per token in seconds * batch size
        decode_tokens_per_second = 1 / np.median(generate_time) * batch_size

        print(f" ** Speed (Prefill): {prefill_tokens_per_second:.2f} tokens/second")
        print(f" ** Speed (Decode): {decode_tokens_per_second:.2f} tokens/second")
        print(f" ** Max Memory (VRAM): {memory_used:.2f} GB ({memory_pct:.2f}%)")
    else:
        prefill_tokens_per_second = 'OOM'
        decode_tokens_per_second = 'OOM'

    return {
        "Batch Size": batch_size,
        "Prefill Length": input_ids.shape[1],
        "Decode Length": n_generate,
        "Prefill tokens/s": prefill_tokens_per_second,
        "Decode tokens/s": decode_tokens_per_second,
        "Memory (VRAM)": f"{memory_used:.2f} GB ({memory_pct:.2f}%)"
    }, model.quant_config["version"]

def main(args):
    rounds = [
        {"context": 32, "n_generate": 32},
        {"context": 64, "n_generate": 64},
        {"context": 128, "n_generate": 128},
        {"context": 256, "n_generate": 256},
        {"context": 512, "n_generate": 512},
        {"context": 1024, "n_generate": 1024},
        {"context": 2048, "n_generate": 2048},
    ]

    all_stats = []
    tokenizer = AutoTokenizer.from_pretrained(args.model_path, trust_remote_code=True)

    for settings in rounds:
        input_ids = torch.randint(0, tokenizer.vocab_size, (args.batch_size, settings["context"])).cuda()

        stats, model_version = run_round(
            args.model_path, args.quant_file, settings["n_generate"],
            input_ids, args.batch_size
        )

        all_stats.append(stats)

        if stats["Prefill tokens/s"] == 'OOM':
            break

    df = pd.DataFrame(all_stats)
    print('GPU:', torch.cuda.get_device_name())
    print('Model:', args.model_path)
    print('Version:', model_version)
    print(df.to_markdown(index=False))

if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_path", type=str, default="casperhansen/vicuna-7b-v1.5-awq", help="path to the model")
    parser.add_argument("--quant_file", type=str, default="awq_model_w4_g128.pt", help="weights filename")
    parser.add_argument("--batch_size", type=int, default=1, help="batch size")
    args = parser.parse_args()

    main(args)
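To make the two throughput formulas in run_round concrete, here is a short worked example with made-up timings:

import numpy as np

batch_size = 2
context_len = 512          # input_ids.shape[1]
context_time = 0.25        # seconds spent on the prefill forward pass
generate_time = [0.021, 0.020, 0.019, 0.020]  # per-token decode latencies in seconds

prefill_tokens_per_second = context_len / context_time * batch_size   # 512 / 0.25 * 2 = 4096.0
decode_tokens_per_second = 1 / np.median(generate_time) * batch_size  # 1 / 0.020 * 2 = 100.0

print(f"Prefill: {prefill_tokens_per_second:.1f} tok/s, Decode: {decode_tokens_per_second:.1f} tok/s")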
setup.py

@@ -44,14 +44,28 @@ requirements = [
     "toml",
     "attributedict",
     "protobuf",
-    "torchvision"
+    "torchvision",
+    "tabulate"
 ]
 
-include_dirs = []
-
-conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
-if os.path.isdir(conda_cuda_include_dir):
-    include_dirs.append(conda_cuda_include_dir)
+def get_include_dirs():
+    include_dirs = []
+
+    conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
+    if os.path.isdir(conda_cuda_include_dir):
+        include_dirs.append(conda_cuda_include_dir)
+
+    this_dir = os.path.dirname(os.path.abspath(__file__))
+    include_dirs.append(this_dir)
+
+    return include_dirs
+
+def get_generator_flag():
+    generator_flag = []
+    torch_dir = torch.__path__[0]
+    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
+        generator_flag = ["-DOLD_GENERATOR_PATH"]
+
+    return generator_flag
 
 def check_dependencies():
     if CUDA_HOME is None:

@@ -77,6 +91,8 @@ def get_compute_capabilities():
     return capability_flags
 
 check_dependencies()
+include_dirs = get_include_dirs()
+generator_flags = get_generator_flag()
 arch_flags = get_compute_capabilities()
 
 if os.name == "nt":

@@ -86,8 +102,21 @@ if os.name == "nt":
     }
 else:
     extra_compile_args = {
-        "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17"],
-        "nvcc": ["-O3", "-std=c++17"] + arch_flags
+        "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
+        "nvcc": [
+            "-O3",
+            "-std=c++17",
+            "-DENABLE_BF16",
+            "-U__CUDA_NO_HALF_OPERATORS__",
+            "-U__CUDA_NO_HALF_CONVERSIONS__",
+            "-U__CUDA_NO_BFLOAT16_OPERATORS__",
+            "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
+            "-U__CUDA_NO_BFLOAT162_OPERATORS__",
+            "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
+            "--expt-relaxed-constexpr",
+            "--expt-extended-lambda",
+            "--use_fast_math",
+        ] + arch_flags + generator_flags
     }

@@ -97,7 +126,10 @@ extensions = [
         "awq_cuda/pybind.cpp",
         "awq_cuda/quantization/gemm_cuda_gen.cu",
         "awq_cuda/layernorm/layernorm.cu",
-        "awq_cuda/position_embedding/pos_encoding_kernels.cu"
+        "awq_cuda/position_embedding/pos_encoding_kernels.cu",
+        "awq_cuda/quantization/gemv_cuda.cu",
+        "awq_cuda/attention/ft_attention.cpp",
+        "awq_cuda/attention/decoder_masked_multihead_attention.cu"
     ], extra_compile_args=extra_compile_args
   )
 ]