OpenDAS / AutoAWQ · Commit ef6b60e2

Commit ef6b60e2, authored Sep 08, 2023 by Casper Hansen

    New kernels

Parent: 84fb7e98

Showing 12 changed files with 4511 additions and 8 deletions (+4511 −8).
awq_cuda/attention/cuda_bf16_fallbacks.cuh                           +257    −0
awq_cuda/attention/cuda_bf16_wrapper.h                                +23    −0
awq_cuda/attention/decoder_masked_multihead_attention.cu             +154    −0
awq_cuda/attention/decoder_masked_multihead_attention.h              +184    −0
awq_cuda/attention/decoder_masked_multihead_attention_template.hpp  +1608    −0
awq_cuda/attention/decoder_masked_multihead_attention_utils.h       +1786    −0
awq_cuda/attention/ft_attention.cpp                                  +182    −0
awq_cuda/attention/ft_attention.h                                     +15    −0
awq_cuda/pybind.cpp                                                     +8    −1
awq_cuda/quantization/gemv_cuda.cu                                   +247    −0
awq_cuda/quantization/gemv_cuda.h                                       +9    −0
setup.py                                                              +38    −7
awq_cuda/attention/cuda_bf16_fallbacks.cuh (new file, mode 100644)
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
namespace fastertransformer {

#ifdef ENABLE_BF16

inline __device__ float2 bf1622float2(const __nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float2 f_val;
    f_val.x = __low2float(val);
    f_val.y = __high2float(val);
    return f_val;
#else
    return __bfloat1622float2(val);
#endif
}

inline __device__ int16_t bf1622int16(__nv_bfloat162 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float2 f_val;
    f_val.x = max(min(__low2float(val), 127.f), -128.f);
    f_val.y = max(min(__high2float(val), 127.f), -128.f);
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
    int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
    return int16;
#else
    val = __hmin2(val, make_bfloat162(127., 127.));
    val = __hmax2(val, make_bfloat162(-128., -128.));
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
    int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
    return int16;
#endif
}

inline __device__ __nv_bfloat162 float22bf162(const float2 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __floats2bfloat162_rn(val.x, val.y);
#else
    return __float22bfloat162_rn(val);
#endif
}

inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    __nv_bfloat162 val2;
    val2.x = val;
    val2.y = val;
    return val2;
#else
    return __bfloat162bfloat162(val);
#endif
}

inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl = __low2float(x), fxh = __high2float(x);
    float fyl = __low2float(y), fyh = __high2float(y);
    return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#else
    return __hadd2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y));
#else
    return __hadd(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl = __low2float(x), fxh = __high2float(x);
    float fyl = __low2float(y), fyh = __high2float(y);
    return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
#else
    return __hsub2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y));
#else
    return __hsub(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl = __low2float(x), fxh = __high2float(x);
    float fyl = __low2float(y), fyh = __high2float(y);
    return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#else
    return __hmul2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y));
#else
    return __hmul(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl = __low2float(x), fxh = __high2float(x);
    float fyl = __low2float(y), fyh = __high2float(y);
    float fzl = __low2float(z), fzh = __high2float(z);
    return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
#else
    return __hfma2(x, y, z);
#endif
}

inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
    return __hfma(x, y, z);
#endif
}

inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl = __low2float(x), fxh = __high2float(x);
    return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
    return h2exp(x);
#endif
}

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); };
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); };

inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y) {
    __nv_bfloat162 t;
    t.x = x;
    t.y = y;
    return t;
}
#endif

inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
    return a + b + c;
#endif
}

inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
#else
    return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d);
#endif
}

inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal = __low2float(a), fah = __high2float(a);
    float fbl = __low2float(b), fbh = __high2float(b);
    float fcl = __low2float(c), fch = __high2float(c);
    return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
#else
    return a + b + c;
#endif
}

inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
    return a * b * c;
#endif
}

inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal = __low2float(a), fah = __high2float(a);
    float fbl = __low2float(b), fbh = __high2float(b);
    float fcl = __low2float(c), fch = __high2float(c);
    return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
#else
    return a * b * c;
#endif
}

inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d) {
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal = __low2float(a), fah = __high2float(a);
    float fbl = __low2float(b), fbh = __high2float(b);
    float fcl = __low2float(c), fch = __high2float(c);
    float fdl = __low2float(d), fdh = __high2float(d);
    return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
#else
    return a * b * c + d;
#endif
}

#endif  // ENABLE_BF16

}  // namespace fastertransformer
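The fallback pattern in this header is easiest to see with a single helper: below SM80 each bf16 pair is widened to FP32, the arithmetic runs in float, and the result is rounded back, while on SM80 and newer the native packed intrinsic is used directly. The following standalone kernel is only an illustrative sketch of exercising bf16hfma2; the kernel itself is not part of this commit and assumes the code is compiled with ENABLE_BF16 defined.

// Illustrative only: a tiny kernel that exercises the bf16hfma2 fallback above.
#include <cuda_bf16.h>
#include "cuda_bf16_fallbacks.cuh"

__global__ void bf16_fma_demo(const __nv_bfloat162* x, const __nv_bfloat162* y,
                              const __nv_bfloat162* z, __nv_bfloat162* out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        // Lowers to __hfma2 on SM80+, and to an FP32 round trip on older architectures.
        out[i] = fastertransformer::bf16hfma2(x[i], y[i], z[i]);
    }
}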
awq_cuda/attention/cuda_bf16_wrapper.h (new file, mode 100644)
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h
/*
* Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
awq_cuda/attention/decoder_masked_multihead_attention.cu (new file, mode 100644)
// Adapted from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "decoder_masked_multihead_attention.h"
#include "decoder_masked_multihead_attention_utils.h"
#include "cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>
#include "decoder_masked_multihead_attention_template.hpp"
////////////////////////////////////////////////////////////////////////////////////////////////////
#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream) \
size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK); \
auto kernel = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, \
THDS_PER_BLOCK, DO_CROSS_ATTENTION>; \
if (smem_sz >= 48 * 1024) { \
cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz); \
} \
dim3 grid(params.num_heads, params.batch_size); \
kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
////////////////////////////////////////////////////////////////////////////////////////////////////
// !!! Specialize the launcher for Cross attention
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    constexpr int  THREADS_PER_VALUE  = Dh_MAX * sizeof(T) / 16;
    constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
    int tlength = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
    // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION);
    if (tlength < 32) {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream);
    }
    else if (tlength < 2048) {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream);
    }
    else {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#undef MMHA_LAUNCH_KERNEL

template<typename T, typename KERNEL_PARAMS_TYPE>
void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    switch (params.hidden_size_per_head) {
        case 32:  mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 48:  mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 64:  mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 80:  mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 96:  mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 112: mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 128: mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 160: mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 192: mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 224: mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream); break;
        case 256: mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream); break;
        default:
            assert(false);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
    multihead_attention_<float, Masked_multihead_attention_params<float>>(params, stream);
}

void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
    multihead_attention_<uint16_t, Masked_multihead_attention_params<uint16_t>>(params, stream);
}

#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
                                const cudaStream_t& stream)
{
    multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
    multihead_attention_<float, Cross_multihead_attention_params<float>>(params, stream);
}

void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
    multihead_attention_<uint16_t, Cross_multihead_attention_params<uint16_t>>(params, stream);
}

#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
                               const cudaStream_t& stream)
{
    multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
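For reference, the MMHA_LAUNCH_KERNEL macro defined at the top of this file expands to roughly the following for one concrete configuration (fp16, Dh = Dh_MAX = 128, the tlength < 2048 branch, so 2 threads per key and 128 threads per block). This expansion is a sketch written out for illustration, not code contained in the commit.

// Sketch of one MMHA_LAUNCH_KERNEL expansion: T = uint16_t, Dh = Dh_MAX = 128,
// THDS_PER_KEY = 2, THDS_PER_BLOCK = 128. Function name is illustrative.
void launch_fp16_dh128(const Masked_multihead_attention_params<uint16_t>& params, cudaStream_t stream)
{
    constexpr int THREADS_PER_VALUE = 128 * sizeof(uint16_t) / 16;  // = 16
    size_t smem_sz = mmha::smem_size_in_bytes<uint16_t, false>(params, THREADS_PER_VALUE, 128);
    auto kernel = mmha::masked_multihead_attention_kernel<uint16_t, 128, 128, 2, THREADS_PER_VALUE, 128, false>;
    if (smem_sz >= 48 * 1024) {
        // Opt in to more than 48 KB of dynamic shared memory when long contexts need it.
        cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz);
    }
    dim3 grid(params.num_heads, params.batch_size);  // one threadblock per (head, batch) pair
    kernel<<<grid, 128, smem_sz, stream>>>(params);
}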
awq_cuda/attention/decoder_masked_multihead_attention.h (new file, mode 100644)
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
////////////////////////////////////////////////////////////////////////////////////////////////////
#define CHECK_CUDA(call) \
do { \
cudaError_t status_ = call; \
if (status_ != cudaSuccess) { \
fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
exit(1); \
} \
} while (0)
////////////////////////////////////////////////////////////////////////////////////////////////////
// The structure of parameters for the masked multihead attention kernel.
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
template<typename T>
struct Multihead_attention_params_base {

    // The output buffer. Dimensions B x D.
    T* out = nullptr;

    // The input Qs and the associated bias. Dimensions B x D and D, resp.
    const T *q = nullptr, *q_bias = nullptr;
    // The input Ks and the associated bias. Dimensions B x D and D, resp.
    const T *k = nullptr, *k_bias = nullptr;
    // The input Vs and the associated bias. Dimensions B x D and D, resp.
    const T *v = nullptr, *v_bias = nullptr;

    // The cache for the Ks. The size must be at least B x L x D.
    T* k_cache = nullptr;
    // The cache for the Vs. The size must be at least B x L x D.
    T* v_cache = nullptr;
    // The indirections to use for cache when beam sampling.
    const int* cache_indir = nullptr;

    // Stride to handle the case when KQV is a single buffer
    int stride = 0;

    // The batch size.
    int batch_size = 0;
    // The beam width
    int beam_width = 0;
    // The sequence length.
    int memory_max_len = 0;
    // The number of heads (H).
    int num_heads = 0;
    // The number of heads for KV cache.
    int num_kv_heads = 0;
    // The hidden dimension per head (Dh).
    int hidden_size_per_head = 0;
    // The per-head latent space reserved for rotary embeddings.
    int   rotary_embedding_dim = 0;
    bool  neox_rotary_style    = false;
    float rotary_base          = 0.0f;
    // The maximum length of input sentences.
    int max_input_length = 0;
    // The current timestep. TODO(bhsueh) Check whether we only need this param in cross attention.
    int timestep = 0;
    // The current timestep of each sentence (supports a different timestep per sentence).

    // The 1.f / sqrt(Dh). Computed on the host.
    float inv_sqrt_dh = 0.0f;

    // Used when we have some input context like gpt
    const int* total_padding_tokens = nullptr;

    const bool* masked_tokens            = nullptr;
    const int*  prefix_prompt_lengths    = nullptr;
    int         max_prefix_prompt_length = 0;

    const T* relative_attention_bias        = nullptr;
    int      relative_attention_bias_stride = 0;
    // The slope per head of linear position bias to attention score (H).
    const float* linear_bias_slopes = nullptr;

    const T*   ia3_key_weights   = nullptr;
    const T*   ia3_value_weights = nullptr;
    const int* ia3_tasks         = nullptr;

    const float* qkv_scale_out       = nullptr;
    const float* attention_out_scale = nullptr;
    int          int8_mode           = 0;
};

template<typename T, bool CROSS_ATTENTION>
struct Multihead_attention_params: public Multihead_attention_params_base<T> {
    // output cross attentions
    float* cross_attention_out        = nullptr;
    int    max_decoder_seq_len        = 0;
    bool   is_return_cross_attentions = false;

    // allows to exit attention early
    bool* finished = nullptr;

    // required in case of cross attention
    // will need it here till if constexpr in c++17
    int* memory_length_per_sample = nullptr;

    // required in case of masked attention with different length
    const int* length_per_sample = nullptr;
};

template<typename T>
struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {
    // output cross attentions
    float* cross_attention_out        = nullptr;
    int    max_decoder_seq_len        = 0;
    bool   is_return_cross_attentions = false;

    // allows to exit attention early
    bool* finished = nullptr;

    // required in case of cross attention
    int* memory_length_per_sample = nullptr;

    // required in case of masked attention with different length
    const int* length_per_sample = nullptr;
};

template<class T>
using Masked_multihead_attention_params = Multihead_attention_params<T, false>;

template<class T>
using Cross_multihead_attention_params = Multihead_attention_params<T, true>;

template<typename T>
struct outputCrossAttentionParam {
    // max decoder output length
    int  max_decoder_seq_len        = 0;
    T*   cross_attention_out        = nullptr;
    bool is_return_cross_attentions = false;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
                                const cudaStream_t& stream);
#endif
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
                               const cudaStream_t& stream);
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
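To make the parameter struct concrete, here is a hedged sketch of how a caller might fill it for a single fp32 decode step with a fused QKV buffer and no beam search. The field names come straight from the header above; the buffer layout, helper name, and sizes are assumptions for illustration only, not code from this commit (the commit's actual entry point is ft_attention.cpp, not shown here).

// Illustrative host-side setup for one decode step (fp32, beam width 1, no biases).
// All device pointers are assumed to be allocated and sized by the caller.
#include <math.h>
#include "decoder_masked_multihead_attention.h"

void run_decode_step(float* d_out, const float* d_qkv, float* d_k_cache, float* d_v_cache,
                     int batch, int heads, int head_dim, int max_len, int step, cudaStream_t stream)
{
    Masked_multihead_attention_params<float> params;
    params.out                  = d_out;
    // Assumes Q, K and V for the current token live in one fused [B, 3*H*Dh] buffer.
    params.q                    = d_qkv;
    params.k                    = d_qkv + heads * head_dim;
    params.v                    = d_qkv + 2 * heads * head_dim;
    params.stride               = 3 * heads * head_dim;
    params.k_cache              = d_k_cache;
    params.v_cache              = d_v_cache;
    params.batch_size           = batch;
    params.beam_width           = 1;
    params.num_heads            = heads;
    params.num_kv_heads         = heads;  // no grouped-query sharing in this sketch
    params.hidden_size_per_head = head_dim;
    params.memory_max_len       = max_len;
    params.timestep             = step;   // number of tokens already in the cache
    params.inv_sqrt_dh          = 1.f / sqrtf((float)head_dim);

    masked_multihead_attention(params, stream);
}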
awq_cuda/attention/decoder_masked_multihead_attention_template.hpp (new file, mode 100644)
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "decoder_masked_multihead_attention.h"
#include "decoder_masked_multihead_attention_utils.h"
#include "cuda_bf16_wrapper.h"
#include "cuda_bf16_fallbacks.cuh"
#include <assert.h>
#include <float.h>
#include <type_traits>
// #define MMHA_USE_HMMA_FOR_REDUCTION
// Below are knobs to extend FP32 accumulation for higher FP16 accuracy
// Does not seem to affect the accuracy that much
#define MMHA_USE_FP32_ACUM_FOR_FMA
// Seems to slightly improve the accuracy
#define MMHA_USE_FP32_ACUM_FOR_OUT
#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT)
// Does not seem to improve the accuracy
//#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#endif
namespace mmha {
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
//
// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use
// 64, 128 and 256 threads per block.
//
// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
// cache buffer helps with memory accesses and contains keys with bias.
//
// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and
// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The
// values for x are chosen to create chunks of 16 bytes.
//
// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs
// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At
// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an
// HMMA instruction (Tensor Core). Each Q * K^T value is stored in shared memory in FP32.
//
// After that loop, a parallel softmax is computed across the different Q * K^T values stored in
// shared memory.
//
// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
// timesteps are computed by loop iteration. As with the keys, the values are read from a cache
// except for the current timestep. The layout of the cache buffer for the values is much simpler
// as it is [B, H, L, Dh].
//
////////////////////////////////////////////////////////////////////////////////////////////////////
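// --- Editorial illustration (not part of the original header) -------------------------------------
// The K-cache indexing described above can be read directly off the layout [B, H, Dh/x, L, x],
// with x = 16 / sizeof(T) elements per 16-byte chunk (8 for FP16, 4 for FP32). The kernel later in
// this file computes the same expression inline; this small helper only restates it for clarity and
// is not used anywhere.
template<typename T>
inline __device__ int k_cache_offset_sketch(int bhi, int memory_max_len, int Dh, int co, int ci, int t)
{
    constexpr int x = 16 / sizeof(T);  // elements per 16B chunk
    // bhi selects the [B, H] slice, co the Dh/x chunk, t the timestep, ci the position in the chunk.
    return bhi * memory_max_len * Dh + co * memory_max_len * x + t * x + ci;
}
// ---------------------------------------------------------------------------------------------------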
template<typename T, int Dh> struct Qk_vec_ {};
template<> struct Qk_vec_<float, 32> { using Type = float; };
template<> struct Qk_vec_<float, 64> { using Type = float2; };
template<> struct Qk_vec_<float, 128> { using Type = float4; };
template<> struct Qk_vec_<float, 256> { using Type = float4; };
template<> struct Qk_vec_<uint16_t, 32> { using Type = uint32_t; };
template<> struct Qk_vec_<uint16_t, 64> { using Type = uint32_t; };
template<> struct Qk_vec_<uint16_t, 128> { using Type = uint2; };
template<> struct Qk_vec_<uint16_t, 256> { using Type = uint4; };
#ifdef ENABLE_BF16
template<> struct Qk_vec_<__nv_bfloat16, 32> { using Type = __nv_bfloat162; };
template<> struct Qk_vec_<__nv_bfloat16, 64> { using Type = __nv_bfloat162; };
template<> struct Qk_vec_<__nv_bfloat16, 128> { using Type = bf16_4_t; };
template<> struct Qk_vec_<__nv_bfloat16, 256> { using Type = bf16_8_t; };
#endif  // ENABLE_BF16

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int THREADS_PER_KEY> struct K_vec_ {};
template<> struct K_vec_<float, 4> { using Type = float; };
template<> struct K_vec_<float, 2> { using Type = float2; };
template<> struct K_vec_<float, 1> { using Type = float4; };
template<> struct K_vec_<uint16_t, 4> { using Type = uint32_t; };
template<> struct K_vec_<uint16_t, 2> { using Type = uint2; };
template<> struct K_vec_<uint16_t, 1> { using Type = uint4; };
#ifdef ENABLE_BF16
template<> struct K_vec_<__nv_bfloat16, 4> { using Type = __nv_bfloat162; };
template<> struct K_vec_<__nv_bfloat16, 2> { using Type = bf16_4_t; };
template<> struct K_vec_<__nv_bfloat16, 1> { using Type = bf16_8_t; };
#endif  // ENABLE_BF16

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int V_VEC_SIZE> struct V_vec_ {};
template<> struct V_vec_<float, 1> { using Type = float; };
template<> struct V_vec_<float, 2> { using Type = float2; };
template<> struct V_vec_<float, 4> { using Type = float4; };
template<> struct V_vec_<uint16_t, 2> { using Type = uint32_t; };
template<> struct V_vec_<uint16_t, 4> { using Type = uint2; };
template<> struct V_vec_<uint16_t, 8> { using Type = uint4; };
#ifdef ENABLE_BF16
template<> struct V_vec_<__nv_bfloat16, 2> { using Type = __nv_bfloat162; };
template<> struct V_vec_<__nv_bfloat16, 4> { using Type = bf16_4_t; };
template<> struct V_vec_<__nv_bfloat16, 8> { using Type = bf16_8_t; };
#endif  // ENABLE_BF16

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
template<typename T> struct Qk_vec_acum_fp32_ {};
template<> struct Qk_vec_acum_fp32_<float> { using Type = float; };
template<> struct Qk_vec_acum_fp32_<float2> { using Type = float2; };
template<> struct Qk_vec_acum_fp32_<float4> { using Type = float4; };
// template<> struct Qk_vec_acum_fp32_<uint16_t> { using Type = float; };
template<> struct Qk_vec_acum_fp32_<uint32_t> { using Type = float2; };
template<> struct Qk_vec_acum_fp32_<uint2> { using Type = Float4_; };
template<> struct Qk_vec_acum_fp32_<uint4> { using Type = Float8_; };
template<> struct Qk_vec_acum_fp32_<__nv_bfloat16> { using Type = float; };
template<> struct Qk_vec_acum_fp32_<__nv_bfloat162> { using Type = float2; };
template<> struct Qk_vec_acum_fp32_<bf16_4_t> { using Type = Float4_; };
template<> struct Qk_vec_acum_fp32_<bf16_8_t> { using Type = Float8_; };

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T> struct K_vec_acum_fp32_ {};
template<> struct K_vec_acum_fp32_<float> { using Type = float; };
template<> struct K_vec_acum_fp32_<float2> { using Type = float2; };
template<> struct K_vec_acum_fp32_<float4> { using Type = float4; };
template<> struct K_vec_acum_fp32_<uint32_t> { using Type = float2; };
template<> struct K_vec_acum_fp32_<uint2> { using Type = Float4_; };
template<> struct K_vec_acum_fp32_<uint4> { using Type = Float8_; };
template<> struct K_vec_acum_fp32_<__nv_bfloat16> { using Type = float; };
template<> struct K_vec_acum_fp32_<__nv_bfloat162> { using Type = float2; };
template<> struct K_vec_acum_fp32_<bf16_4_t> { using Type = Float4_; };
template<> struct K_vec_acum_fp32_<bf16_8_t> { using Type = Float8_; };
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template<typename T> struct V_vec_acum_fp32_ {};
template<> struct V_vec_acum_fp32_<float> { using Type = float; };
template<> struct V_vec_acum_fp32_<float2> { using Type = float2; };
template<> struct V_vec_acum_fp32_<float4> { using Type = float4; };
template<> struct V_vec_acum_fp32_<uint32_t> { using Type = float2; };
template<> struct V_vec_acum_fp32_<uint2> { using Type = Float4_; };
template<> struct V_vec_acum_fp32_<uint4> { using Type = Float8_; };
#ifdef ENABLE_BF16
template<> struct V_vec_acum_fp32_<__nv_bfloat162> { using Type = float2; };
template<> struct V_vec_acum_fp32_<bf16_4_t> { using Type = Float4_; };
template<> struct V_vec_acum_fp32_<bf16_8_t> { using Type = Float8_; };
#endif  // ENABLE_BF16
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int THREADS_PER_KEY, typename K_vec, int N>
inline __device__ float qk_dot_(const K_vec (&q)[N], const K_vec (&k)[N])
{
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
    using K_vec_acum = typename K_vec_acum_fp32_<K_vec>::Type;
#else
    using K_vec_acum = K_vec;
#endif
    // Compute the parallel products for Q*K^T (treat vector lanes separately).
    K_vec_acum qk_vec = mul<K_vec_acum, K_vec, K_vec>(q[0], k[0]);
#pragma unroll
    for (int ii = 1; ii < N; ++ii) {
        qk_vec = fma(q[ii], k[ii], qk_vec);
    }

    // Finalize the reduction across lanes.
    float qk = sum(qk_vec);
#pragma unroll
    for (int mask = THREADS_PER_KEY / 2; mask >= 1; mask /= 2) {
        qk += __shfl_xor_sync(uint32_t(-1), qk, mask);
    }
    return qk;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int THREADS_PER_KEY>
struct Qk_dot {
    template<typename K_vec, int N>
    static inline __device__ float dot(const K_vec (&q)[N], const K_vec (&k)[N])
    {
        return qk_dot_<THREADS_PER_KEY>(q, k);
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float4 hmma_fp32(const uint2& a, uint32_t b)
{
    float4 c;
    float  zero = 0.f;
    asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 \n"
                 "    {%0, %1, %2, %3}, \n"
                 "    {%4, %5}, \n"
                 "    {%6}, \n"
                 "    {%7, %7, %7, %7}; \n"
                 : "=f"(c.x), "=f"(c.y), "=f"(c.z), "=f"(c.w)
                 : "r"(a.x), "r"(a.y), "r"(b), "f"(zero));
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int N>
inline __device__ float qk_hmma_dot_(const uint32_t (&q)[N], const uint32_t (&k)[N])
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
    using K_vec_acum = typename K_vec_acum_fp32_<uint32_t>::Type;
#else
    using K_vec_acum = uint32_t;
#endif
    K_vec_acum qk_vec = mul<K_vec_acum, uint32_t, uint32_t>(q[0], k[0]);
#pragma unroll
    for (int ii = 1; ii < N; ++ii) {
        qk_vec = fma(q[ii], k[ii], qk_vec);
    }
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
    uint32_t qk_vec_ = float2_to_half2(qk_vec);
    return hmma_fp32(make_uint2(qk_vec_, 0u), 0x3c003c00u).x;
#else
    return hmma_fp32(make_uint2(qk_vec, 0u), 0x3c003c00u).x;
#endif
#else
    return 0.f;
#endif
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<>
struct Qk_dot<uint16_t, 4> {
    template<int N>
    static inline __device__ float dot(const uint32_t (&q)[N], const uint32_t (&k)[N])
    {
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
        return qk_hmma_dot_(q, k);
#else
        return qk_dot_<4>(q, k);
#endif  // defined MMHA_USE_HMMA_FOR_REDUCTION
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

template<int WARPS_PER_BLOCK, int WARP_SIZE = 32>
inline __device__ float block_sum(float* red_smem, float sum)
{
    // Decompose the thread index into warp / lane.
    int warp = threadIdx.x / WARP_SIZE;
    int lane = threadIdx.x % WARP_SIZE;

// Compute the sum per warp.
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= 1; mask /= 2) {
        sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
    }

    // Warp leaders store the data to shared memory.
    if (lane == 0) {
        red_smem[warp] = sum;
    }

    // Make sure the data is in shared memory.
    __syncthreads();

    // The warps compute the final sums.
    if (lane < WARPS_PER_BLOCK) {
        sum = red_smem[lane];
    }

// Parallel reduction inside the warp.
#pragma unroll
    for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
        sum += __shfl_xor_sync(uint32_t(-1), sum, mask);
    }

    // Broadcast to other threads.
    return __shfl_sync(uint32_t(-1), sum, 0);
}

////////////////////////////////////////////////////////////////////////////////////////////////////
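// --- Editorial illustration (not part of the original header) -------------------------------------
// block_sum is the standard two-level reduction: shuffle within each warp, stage one partial per
// warp in shared memory, then reduce those partials in the first warp and broadcast from lane 0.
// A hypothetical call site, assuming red_smem holds at least WARPS_PER_BLOCK floats:
//
//     __shared__ float red_smem[WARPS_PER_BLOCK * 2];   // one half for max, one half for sum
//     float partial = ...;                              // per-thread partial value
//     float total   = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], partial);
//
// Every thread in the block receives the same 'total'.
// ---------------------------------------------------------------------------------------------------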
inline
__device__
void
convert_from_float
(
float
&
dst
,
float
src
)
{
dst
=
src
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint16_t
&
dst
,
float
src
)
{
dst
=
float_to_half
(
src
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint32_t
&
dst
,
float2
src
)
{
dst
=
float2_to_half2
(
src
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
void
convert_from_float
(
__nv_bfloat16
&
dst
,
float
src
)
{
dst
=
__float2bfloat16
(
src
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
__nv_bfloat162
&
dst
,
float2
src
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst
=
__float22bfloat162_rn
(
src
);
#else
dst
=
__floats2bfloat162_rn
(
src
.
x
,
src
.
y
);
#endif
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint2
&
dst
,
Float4_
src
)
{
dst
.
x
=
float2_to_half2
(
src
.
x
);
dst
.
y
=
float2_to_half2
(
src
.
y
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint2
&
dst
,
float4
src
)
{
convert_from_float
(
dst
,
Float4_
{
make_float2
(
src
.
x
,
src
.
y
),
make_float2
(
src
.
z
,
src
.
w
)});
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint4
&
dst
,
Float8_
src
)
{
dst
.
x
=
float2_to_half2
(
src
.
x
);
dst
.
y
=
float2_to_half2
(
src
.
y
);
dst
.
z
=
float2_to_half2
(
src
.
z
);
dst
.
w
=
float2_to_half2
(
src
.
w
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
void
convert_from_float
(
bf16_4_t
&
dst
,
Float4_
src
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst
.
x
=
__float22bfloat162_rn
(
src
.
x
);
dst
.
y
=
__float22bfloat162_rn
(
src
.
y
);
#else
dst
.
x
=
__floats2bfloat162_rn
(
src
.
x
.
x
,
src
.
x
.
y
);
dst
.
y
=
__floats2bfloat162_rn
(
src
.
y
.
x
,
src
.
y
.
y
);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
bf16_4_t
&
dst
,
float4
src
)
{
convert_from_float
(
dst
,
Float4_
{
make_float2
(
src
.
x
,
src
.
y
),
make_float2
(
src
.
z
,
src
.
w
)});
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
bf16_8_t
&
dst
,
Float8_
src
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst
.
x
=
__float22bfloat162_rn
(
src
.
x
);
dst
.
y
=
__float22bfloat162_rn
(
src
.
y
);
dst
.
z
=
__float22bfloat162_rn
(
src
.
z
);
dst
.
w
=
__float22bfloat162_rn
(
src
.
w
);
#else
dst
.
x
=
__floats2bfloat162_rn
(
src
.
x
.
x
,
src
.
x
.
y
);
dst
.
y
=
__floats2bfloat162_rn
(
src
.
y
.
x
,
src
.
y
.
y
);
dst
.
z
=
__floats2bfloat162_rn
(
src
.
z
.
x
,
src
.
z
.
y
);
dst
.
w
=
__floats2bfloat162_rn
(
src
.
w
.
x
,
src
.
w
.
y
);
#endif
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
float2
&
dst
,
float2
src
)
{
dst
=
src
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
float4
&
dst
,
float4
src
)
{
dst
=
src
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
convert_to_float
(
float4
u
)
{
return
u
.
x
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
convert_to_float
(
uint4
u
)
{
float2
tmp
=
half2_to_float2
(
u
.
x
);
return
tmp
.
x
;
}
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
cast_to_float
(
float
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
cast_to_float
(
float2
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
cast_to_float
(
float4
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
cast_to_float
(
Float4_
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
cast_to_float
(
Float8_
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
cast_to_float
(
uint32_t
u
)
{
return
half2_to_float2
(
u
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
cast_to_float
(
uint2
u
)
{
Float4_
tmp
;
tmp
.
x
=
half2_to_float2
(
u
.
x
);
tmp
.
y
=
half2_to_float2
(
u
.
y
);
return
tmp
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
cast_to_float
(
uint4
u
)
{
Float8_
tmp
;
tmp
.
x
=
half2_to_float2
(
u
.
x
);
tmp
.
y
=
half2_to_float2
(
u
.
y
);
tmp
.
z
=
half2_to_float2
(
u
.
z
);
tmp
.
w
=
half2_to_float2
(
u
.
w
);
return
tmp
;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
float_from_int8
(
int8_t
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
float_from_int8
(
int16_t
u
)
{
union
{
int16_t
int16
;
int8_t
int8
[
2
];
};
int16
=
u
;
return
make_float2
(
int8
[
0
],
int8
[
1
]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
float_from_int8
(
int32_t
u
)
{
union
{
int32_t
int32
;
int8_t
int8
[
4
];
};
int32
=
u
;
return
make_float4
(
int8
[
0
],
int8
[
1
],
int8
[
2
],
int8
[
3
]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// clang-format off
inline
__device__
Float8_
float_from_int8
(
int64_t
u
)
{
union
{
int64_t
int64
;
int16_t
int16
[
4
];
};
int64
=
u
;
return
Float8_
{
float_from_int8
(
int16
[
0
]),
float_from_int8
(
int16
[
1
]),
float_from_int8
(
int16
[
2
]),
float_from_int8
(
int16
[
3
])};
}
// clang-format on
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
int8_t
cast_to_int8
(
float
val
)
{
union
{
int8_t
int8
[
2
];
int16_t
int16
;
};
asm
volatile
(
"cvt.rni.sat.s8.f32 %0, %1;"
:
"=h"
(
int16
)
:
"f"
(
val
));
return
int8
[
0
];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
int32_t
cast_to_int8
(
float4
val
)
{
union
{
int8_t
int8
[
4
];
int32_t
int32
;
};
int8
[
0
]
=
cast_to_int8
(
val
.
x
);
int8
[
1
]
=
cast_to_int8
(
val
.
y
);
int8
[
2
]
=
cast_to_int8
(
val
.
z
);
int8
[
3
]
=
cast_to_int8
(
val
.
w
);
return
int32
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
int64_t
cast_to_int8
(
Float8_
val
)
{
union
{
int8_t
int8
[
8
];
int64_t
int64
;
};
int8
[
0
]
=
cast_to_int8
(
val
.
x
.
x
);
int8
[
1
]
=
cast_to_int8
(
val
.
x
.
y
);
int8
[
2
]
=
cast_to_int8
(
val
.
y
.
x
);
int8
[
3
]
=
cast_to_int8
(
val
.
y
.
y
);
int8
[
4
]
=
cast_to_int8
(
val
.
z
.
x
);
int8
[
5
]
=
cast_to_int8
(
val
.
z
.
y
);
int8
[
6
]
=
cast_to_int8
(
val
.
w
.
x
);
int8
[
7
]
=
cast_to_int8
(
val
.
w
.
y
);
return
int64
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
>
inline
__device__
__host__
T
div_up
(
T
m
,
T
n
)
{
return
(
m
+
n
-
1
)
/
n
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
,
bool
DO_CROSS_ATTENTION
>
inline
size_t
smem_size_in_bytes
(
const
Multihead_attention_params
<
T
,
DO_CROSS_ATTENTION
>&
params
,
int
threads_per_value
,
int
threads_per_block
)
{
// The amount of shared memory needed to store the Q*K^T values in float.
const
int
max_timesteps
=
min
(
params
.
timestep
,
params
.
memory_max_len
);
size_t
qk_sz
=
(
DO_CROSS_ATTENTION
)
?
div_up
(
params
.
memory_max_len
+
1
,
4
)
*
16
:
div_up
(
max_timesteps
+
1
,
4
)
*
16
;
// The extra memory needed if we are not using floats for the final logits.
size_t
logits_sz
=
0
;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
if
(
sizeof
(
T
)
!=
4
)
{
// TDOD
logits_sz
=
(
DO_CROSS_ATTENTION
)
?
div_up
(
params
.
memory_max_len
+
1
,
4
)
*
4
*
sizeof
(
T
)
:
div_up
(
max_timesteps
+
1
,
4
)
*
4
*
sizeof
(
T
);
}
#endif
// The total size needed during softmax.
size_t
softmax_sz
=
qk_sz
+
logits_sz
;
// The number of partial rows to reduce in the final reduction.
int
rows_per_red
=
threads_per_block
/
threads_per_value
;
// The amount of storage needed to finalize the outputs.
size_t
red_sz
=
rows_per_red
*
params
.
hidden_size_per_head
*
sizeof
(
T
)
/
2
;
size_t
transpose_rotary_size
=
0
;
if
(
params
.
rotary_embedding_dim
>
0
&&
params
.
neox_rotary_style
)
{
transpose_rotary_size
=
2
*
params
.
rotary_embedding_dim
*
sizeof
(
T
);
}
// The max.
return
max
(
max
(
softmax_sz
,
red_sz
),
transpose_rotary_size
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
constexpr
uint32_t
shfl_mask
(
int
threads
)
{
return
threads
==
32
?
uint32_t
(
-
1
)
:
(
1u
<<
threads
)
-
1u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
// The type of the inputs. Supported types: float and half.
typename
T
,
// The hidden dimension per head.
int
Dh
,
int
Dh_MAX
,
// The number of threads per key.
int
THREADS_PER_KEY
,
// The number of threads per value.
int
THREADS_PER_VALUE
,
// The number of threads in a threadblock.
int
THREADS_PER_BLOCK
,
bool
DO_CROSS_ATTENTION
>
__global__
void
masked_multihead_attention_kernel
(
Multihead_attention_params
<
T
,
DO_CROSS_ATTENTION
>
params
)
{
// Make sure the hidden dimension per head is a multiple of the number of threads per key.
static_assert
(
Dh_MAX
%
THREADS_PER_KEY
==
0
,
""
);
// Make sure the hidden dimension per head is a multiple of the number of threads per value.
static_assert
(
Dh_MAX
%
THREADS_PER_VALUE
==
0
,
""
);
// The size of a warp.
constexpr
int
WARP_SIZE
=
32
;
// The number of warps in a threadblock.
constexpr
int
WARPS_PER_BLOCK
=
THREADS_PER_BLOCK
/
WARP_SIZE
;
// Use smem_size_in_bytes (above) to determine the amount of shared memory.
extern
__shared__
char
smem_
[];
// The shared memory for the Q*K^T values and partial logits in softmax.
float
*
qk_smem
=
reinterpret_cast
<
float
*>
(
smem_
);
// The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
char
*
logits_smem_
=
smem_
;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
if
(
sizeof
(
T
)
!=
4
)
{
// TODO - change to tlength
const
int
max_timesteps
=
min
(
params
.
timestep
,
params
.
memory_max_len
);
logits_smem_
+=
(
DO_CROSS_ATTENTION
)
?
div_up
(
params
.
memory_max_len
+
1
,
4
)
*
16
:
div_up
(
max_timesteps
+
1
,
4
)
*
16
;
}
T
*
logits_smem
=
reinterpret_cast
<
T
*>
(
logits_smem_
);
#else
float
*
logits_smem
=
reinterpret_cast
<
float
*>
(
logits_smem_
);
#endif
// The shared memory to do the final reduction for the output values. Reuse qk_smem.
T
*
out_smem
=
reinterpret_cast
<
T
*>
(
smem_
);
// The shared memory buffers for the block-wide reductions. One for max, one for sum.
__shared__
float
red_smem
[
WARPS_PER_BLOCK
*
2
];
// A vector of Q or K elements for the current timestep.
using
Qk_vec
=
typename
Qk_vec_
<
T
,
Dh_MAX
>::
Type
;
// Use alignment for safely casting the shared buffers as Qk_vec.
// Shared memory to store Q inputs.
__shared__
__align__
(
sizeof
(
Qk_vec
))
T
q_smem
[
Dh_MAX
];
// This is one of the reasons we should have a separate kernel for cross attention
__shared__
__align__
(
sizeof
(
Qk_vec
))
T
bias_smem
[
DO_CROSS_ATTENTION
?
Dh_MAX
:
1
];
// A vector of Q or K elements for the current timestep.
using
Qk_vec
=
typename
Qk_vec_
<
T
,
Dh_MAX
>::
Type
;
// The number of elements per vector.
constexpr
int
QK_VEC_SIZE
=
sizeof
(
Qk_vec
)
/
sizeof
(
T
);
// Make sure the hidden size per head is a multiple of the vector size.
static_assert
(
Dh_MAX
%
QK_VEC_SIZE
==
0
,
""
);
// We will use block wide reduction if needed
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
// The number of vectors per warp.
constexpr
int
QK_VECS_PER_WARP
=
Dh_MAX
/
QK_VEC_SIZE
;
// The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread
// owns x elements, we have to decompose the linear index into chunks of x values and the posi-
// tion of the thread in that chunk.
// The number of elements in a chunk of 16B (that's the x in the above formula).
constexpr
int
QK_ELTS_IN_16B
=
16
/
sizeof
(
T
);
// The number of K vectors in 16B.
constexpr
int
QK_VECS_IN_16B
=
16
/
sizeof
(
Qk_vec
);
// The batch/beam idx
const
int
bi
=
blockIdx
.
y
;
if
(
params
.
finished
!=
nullptr
&&
params
.
finished
[
bi
]
==
true
)
{
return
;
}
// The beam idx
const
int
beami
=
bi
%
params
.
beam_width
;
// The "beam-aware" batch idx
const
int
bbi
=
bi
/
params
.
beam_width
;
// The head.
const
int
num_kv_heads
=
params
.
num_kv_heads
;
const
int
kv_rep
=
(
params
.
num_heads
/
num_kv_heads
);
const
int
hi
=
blockIdx
.
x
;
const
int
hi_kv
=
hi
/
kv_rep
;
// Combine the batch and the head indices.
const
int
bhi
=
bi
*
params
.
num_heads
+
hi
;
const
int
bhi_kv
=
bi
*
(
params
.
num_heads
/
kv_rep
)
+
hi_kv
;
// Combine the "beam-aware" batch idx and the head indices.
const
int
bbhi
=
bbi
*
params
.
beam_width
*
params
.
num_heads
+
hi
;
const
int
bbhi_kv
=
bbi
*
params
.
beam_width
*
(
params
.
num_heads
/
kv_rep
)
+
hi_kv
;
// The thread in the block.
const
int
tidx
=
threadIdx
.
x
;
const
bool
handle_kv
=
!
DO_CROSS_ATTENTION
||
(
DO_CROSS_ATTENTION
&&
params
.
timestep
==
0
);
// Every kv_rep threads have the same kv_cache values. So only the first one writes back.
const
int
write_kv_cache
=
handle_kv
&&
(
hi
%
kv_rep
==
0
);
// While doing the product Q*K^T for the different keys we track the max.
float
qk_max
=
-
FLT_MAX
;
float
qk
=
0.0
F
;
// int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
const
int
q_base_offset
=
bi
*
params
.
stride
+
hi
*
Dh
;
const
int
k_base_offset
=
bi
*
params
.
stride
+
hi_kv
*
Dh
;
const
int
v_base_offset
=
k_base_offset
;
const
size_t
bi_seq_len_offset
=
bi
*
params
.
memory_max_len
;
// int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep;
int
tlength
=
(
DO_CROSS_ATTENTION
)
?
params
.
memory_length_per_sample
[
bi
]
-
1
:
(
params
.
length_per_sample
==
nullptr
)
?
params
.
timestep
:
params
.
length_per_sample
[
bi
]
+
params
.
max_prefix_prompt_length
;
const
int
first_step
=
max
(
0
,
tlength
+
1
-
params
.
memory_max_len
);
const
int
tlength_circ
=
tlength
%
params
.
memory_max_len
;
// First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
const
bool
is_masked
=
tidx
>=
QK_VECS_PER_WARP
;
// The offset in the Q and K buffer also accounts for the batch.
// int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
int
q_offset
=
q_base_offset
+
tidx
*
QK_VEC_SIZE
;
int
k_offset
=
k_base_offset
+
tidx
*
QK_VEC_SIZE
;
int
v_offset
=
k_offset
;
// The offset in the bias buffer.
// int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
int
q_bias_offset
=
hi
*
Dh
+
tidx
*
QK_VEC_SIZE
;
int
k_bias_offset
=
hi_kv
*
Dh
+
tidx
*
QK_VEC_SIZE
;
int
v_bias_offset
=
k_bias_offset
;
const
bool
do_ia3
=
handle_kv
&&
params
.
ia3_tasks
!=
nullptr
;
const
int
ia3_task_id
=
do_ia3
?
params
.
ia3_tasks
[
bbi
]
:
0
;
// Trigger the loads from the Q and K buffers.
Qk_vec
q
;
zero
(
q
);
if
(
!
is_masked
&&
(
Dh
==
Dh_MAX
||
tidx
*
QK_VEC_SIZE
<
Dh
))
{
if
(
params
.
int8_mode
==
2
)
{
using
Packed_Int8_t
=
typename
packed_type
<
int8_t
,
num_elems
<
Qk_vec
>::
value
>::
type
;
using
Packed_Float_t
=
typename
packed_type
<
float
,
num_elems
<
Qk_vec
>::
value
>::
type
;
const
auto
q_scaling
=
params
.
qkv_scale_out
[
0
];
const
auto
q_quant
=
*
reinterpret_cast
<
const
Packed_Int8_t
*>
(
&
reinterpret_cast
<
const
int8_t
*>
(
params
.
q
)[
q_offset
]);
convert_from_float
(
q
,
mul
<
Packed_Float_t
,
float
>
(
q_scaling
,
float_from_int8
(
q_quant
)));
}
else
{
q
=
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
params
.
q
[
q_offset
]);
}
}
Qk_vec
k
;
zero
(
k
);
if
(
DO_CROSS_ATTENTION
)
{
// The 16B chunk written by the thread.
int
co
=
tidx
/
QK_VECS_IN_16B
;
// The position of the thread in that 16B chunk.
int
ci
=
tidx
%
QK_VECS_IN_16B
*
QK_VEC_SIZE
;
// Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
int
offset
=
bhi_kv
*
params
.
memory_max_len
*
Dh
+
co
*
params
.
memory_max_len
*
QK_ELTS_IN_16B
+
// params.timestep*QK_ELTS_IN_16B +
tlength
*
QK_ELTS_IN_16B
+
ci
;
k
=
!
is_masked
&&
(
Dh
==
Dh_MAX
||
tidx
*
QK_VEC_SIZE
<
Dh
)
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
params
.
k_cache
[
offset
])
:
k
;
}
else
{
if
(
!
is_masked
&&
(
Dh
==
Dh_MAX
||
tidx
*
QK_VEC_SIZE
<
Dh
))
{
if
(
params
.
int8_mode
==
2
)
{
using
Packed_Int8_t
=
typename
packed_type
<
int8_t
,
num_elems
<
Qk_vec
>::
value
>::
type
;
using
Packed_Float_t
=
typename
packed_type
<
float
,
num_elems
<
Qk_vec
>::
value
>::
type
;
const
auto
k_scaling
=
params
.
qkv_scale_out
[
1
];
const
auto
k_quant
=
*
reinterpret_cast
<
const
Packed_Int8_t
*>
(
&
reinterpret_cast
<
const
int8_t
*>
(
params
.
k
)[
k_offset
]);
convert_from_float
(
k
,
mul
<
Packed_Float_t
,
float
>
(
k_scaling
,
float_from_int8
(
k_quant
)));
}
else
{
k
=
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
params
.
k
[
k_offset
]);
}
}
}
// Trigger the loads from the Q and K bias buffers.
Qk_vec
q_bias
;
zero
(
q_bias
);
q_bias
=
(
!
is_masked
&&
Dh
==
Dh_MAX
||
tidx
*
QK_VEC_SIZE
<
Dh
)
&&
params
.
q_bias
!=
nullptr
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
params
.
q_bias
[
q_bias_offset
])
:
q_bias
;
Qk_vec
k_bias
;
zero
(
k_bias
);
if
(
handle_kv
)
{
k_bias
=
!
is_masked
&&
(
Dh
==
Dh_MAX
||
tidx
*
QK_VEC_SIZE
<
Dh
)
&&
params
.
k_bias
!=
nullptr
?
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
params
.
k_bias
[
k_bias_offset
])
:
k_bias
;
}
// Computes the Q/K values with bias.
q
=
add
(
q
,
q_bias
);
if
(
handle_kv
)
{
k
=
add
(
k
,
k_bias
);
}
if
(
do_ia3
&&
!
is_masked
)
{
k
=
mul
<
Qk_vec
,
Qk_vec
,
Qk_vec
>
(
k
,
*
reinterpret_cast
<
const
Qk_vec
*>
(
&
params
.
ia3_key_weights
[(
ia3_task_id
*
params
.
num_heads
+
hi
)
*
Dh
+
tidx
*
QK_VEC_SIZE
]));
}
    // Padded len
    const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
    if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
        if (handle_kv) {
            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
        }
        else {
            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
        }
    }
    else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
        const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;

        T* q_smem = reinterpret_cast<T*>(smem_);
        T* k_smem = q_smem + params.rotary_embedding_dim;

        const int half_rotary_dim = params.rotary_embedding_dim / 2;
        const int half_idx        = (tidx * QK_VEC_SIZE) / half_rotary_dim;
        const int intra_half_idx  = (tidx * QK_VEC_SIZE) % half_rotary_dim;
        const int smem_pitch      = half_rotary_dim;
        // TODO: adjust for bank conflicts

        assert(half_rotary_dim % QK_VEC_SIZE == 0);

        if (do_rotary) {
            *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;

            if (handle_kv) {
                *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
            }
        }

        __syncthreads();

        const int     transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
        constexpr int tidx_factor   = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1;
        if (do_rotary) {
            mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch);

            if (handle_kv) {
                mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);

                mmha::apply_rotary_embedding(
                    q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);

                mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
            }
            else {
                mmha::apply_rotary_embedding(
                    q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
            }
            mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
        }

        __syncthreads();

        if (do_rotary) {
            q = *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx);

            if (handle_kv) {
                k = *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx);
            }
        }

        __syncthreads();
    }
    if (!is_masked) {
        // Store the Q values to shared memory.
        *reinterpret_cast<Qk_vec*>(&q_smem[tidx * QK_VEC_SIZE]) = q;

        // Store Dh values of k_bias into smem, since will need to add later
        // if params.timestep == 0
        if (DO_CROSS_ATTENTION && params.timestep == 0) {
            *reinterpret_cast<Qk_vec*>(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias;
        }

        // Write the K values to the global memory cache.
        //
        // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
        // system. We designed it this way as it allows much better memory loads (and there are many
        // more loads) + the stores are really "write and forget" since we won't need the ack before
        // the end of the kernel. There's plenty of time for the transactions to complete.

        // The 16B chunk written by the thread.
        int co = tidx / QK_VECS_IN_16B;
        // The position of the thread in that 16B chunk.
        int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;

        // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
        int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                     // params.timestep*QK_ELTS_IN_16B +
                     tlength_circ * QK_ELTS_IN_16B + ci;

        if (write_kv_cache) {
            // Trigger the stores to global memory.
            if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
                *reinterpret_cast<Qk_vec*>(&params.k_cache[offset]) = k;
            }
        }

        // Compute \sum_i Q[i] * K^T[i] for the current timestep.
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
        using Qk_vec_acum = typename Qk_vec_acum_fp32_<Qk_vec>::Type;
#else
        using Qk_vec_acum = Qk_vec;
#endif
        qk = dot<Qk_vec_acum, Qk_vec>(q, k);
        if (QK_VECS_PER_WARP <= WARP_SIZE) {
#pragma unroll
            for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
                qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
            }
        }
    }
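The K-cache write above addresses what is effectively a [batch, kv_head, Dh / QK_ELTS_IN_16B, memory_max_len, QK_ELTS_IN_16B] layout: all timesteps of a given 16-byte chunk sit contiguously, which keeps the per-timestep loads in the key loop below wide while these one-off stores stay scattered, as the NOTE explains. Restating the index arithmetic in one expression (no new quantities, just the code above):

    offset = bhi_kv * memory_max_len * Dh
           + co * memory_max_len * QK_ELTS_IN_16B
           + tlength_circ * QK_ELTS_IN_16B
           + ci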
    if (QK_VECS_PER_WARP > WARP_SIZE) {
        constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
        qk                          = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
    }

    // Store that value in shared memory. Keep the Q*K^T value in register for softmax.
    if (tidx == 0) {
        // Normalize qk.
        qk *= params.inv_sqrt_dh;
        if (params.relative_attention_bias != nullptr) {
            // TODO (Haotian): check whether we should replace hi with hi_kv,
            // although params.relative_attention_bias is usually not used.
            qk = add(qk,
                     params.relative_attention_bias[hi * params.relative_attention_bias_stride
                                                        * params.relative_attention_bias_stride
                                                    + (tlength - padd_len) * params.relative_attention_bias_stride
                                                    + (tlength - padd_len)]);
        }
        // Add alibi positional encoding
        // qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0;
        // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.

        qk_max                        = qk;
        qk_smem[tlength - first_step] = qk;
        // qk_smem[params.timestep] = qk;
    }

    // Make sure the data is in shared memory.
    __syncthreads();
    // The type of queries and keys for the math in the Q*K^T product.
    using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
    // The number of elements per vector.
    constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
    // Make sure the hidden size per head is a multiple of the vector size.
    static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
    // The number of elements per thread.
    constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
    // The number of vectors per thread.
    constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;

    // The position the first key loaded by each thread from the cache buffer (for this B * H).
    int ko = tidx / THREADS_PER_KEY;
    // The position of the thread in the chunk of keys.
    int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE;

    static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD);

    // Load the Q values from shared memory. The values are reused during the loop on K.
    K_vec q_vec[K_VECS_PER_THREAD];
#pragma unroll
    for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
        q_vec[ii] = *reinterpret_cast<const K_vec*>(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
    }

    K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1];
    if (DO_CROSS_ATTENTION && params.timestep == 0) {
#pragma unroll
        for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
            k_bias_vec[ii] = *reinterpret_cast<const K_vec*>(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
        }
    }
    // The number of timesteps loaded per iteration.
    constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
    // The number of keys per warp.
    constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;

    // The base pointer for the key in the cache buffer.
    T* k_cache = &params.k_cache[bhi_kv * params.memory_max_len * Dh + ki];
    // Base pointer for the beam's batch, before offsetting with indirection buffer
    T* k_cache_batch = &params.k_cache[bbhi_kv * params.memory_max_len * Dh + ki];

    // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
    // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
    int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;

    // prefix prompt length if has
    const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi];

    // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
    const bool has_beams    = params.cache_indir != nullptr;
    const int* beam_indices = has_beams ? &params.cache_indir[bi_seq_len_offset] : nullptr;
    for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
        const int ti_circ = ti % params.memory_max_len;

        // The keys loaded from the key cache.
        K_vec k[K_VECS_PER_THREAD];
        K_vec k_vec_zero;
        zero(k_vec_zero);
#pragma unroll
        for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
            int jj = ii * params.memory_max_len + ti_circ;
            // if( ti < params.timestep ) {
            const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len);
            if (ti < tlength) {
                if (!within_bounds) {
                    k[ii] = k_vec_zero;
                }
                else {
                    if (has_beams) {
                        const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
                        k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]);
                    }
                    else {
                        k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[jj * QK_ELTS_IN_16B]);
                    }
                }

                // add bias and update k_cache
                if (DO_CROSS_ATTENTION && params.timestep == 0) {
                    k[ii] = add(k[ii], k_bias_vec[ii]);

                    if (do_ia3) {
                        k[ii] = mul<K_vec, K_vec, K_vec>(
                            k[ii],
                            *reinterpret_cast<const K_vec*>(
                                &params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki
                                                        + ii * THREADS_PER_KEY * K_VEC_SIZE]));
                    }

                    if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) {
                        *reinterpret_cast<K_vec*>(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii];
                    }
                }
            }
        }
        // Perform the dot product and normalize qk.
        //
        // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
        float qk = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k) * params.inv_sqrt_dh;

        bool is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];

        // Store the product to shared memory. There's one qk value per timestep. Update the max.
        // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) {
        if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
            if (params.relative_attention_bias != nullptr) {
                qk = add(qk,
                         params.relative_attention_bias[hi * params.relative_attention_bias_stride
                                                            * params.relative_attention_bias_stride
                                                        + tlength * params.relative_attention_bias_stride + ti]);
            }
            if (params.linear_bias_slopes != nullptr) {
                // Apply the linear position bias: (ki - qi) * slope[hi].
                // The padding token locates between the input context and the generated tokens.
                // We need to remove the number of padding tokens in the distance computation.
                //   ti   : 0 1 2 3 4 5 6 7 8 9(tlength)
                //   token: i i i i p p p o o o where i=input, p=pad, o=output.
                // e.g. ti = 2, dist = (9 - 3) - 2 = 4.
                int   max_context_length = params.max_prefix_prompt_length + params.max_input_length;
                float dist               = (ti < max_context_length ? ti + padd_len : ti) - tlength;

                qk += mul<float, float, float>(params.linear_bias_slopes[hi], dist);
            }
            // Add alibi positional encoding
            // qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0;

            qk_max                   = is_mask ? qk_max : fmaxf(qk_max, qk);
            qk_smem[ti - first_step] = qk;
        }
    }
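When params.linear_bias_slopes is set, the branch above adds an ALiBi-style linear position bias. With the context length L_ctx = max_prefix_prompt_length + max_input_length, the distance scaled by the per-head slope is

    dist = (ti < L_ctx ? ti + padd_len : ti) - tlength
    qk  += linear_bias_slopes[hi] * dist

so past positions contribute a non-positive bias whose magnitude grows with distance; heads with larger slopes down-weight distant tokens more aggressively.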
    // Perform the final reduction to compute the max inside each warp.
    //
    // NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
    // group so it's not needed to run the reduction inside the group (again).
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
        qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
    }

    // Decompose the thread index into warp and lane.
    const int warp = tidx / WARP_SIZE;
    const int lane = tidx % WARP_SIZE;

    // The warp leader writes the max to shared memory.
    if (lane == 0) {
        red_smem[warp] = qk_max;
    }

    // Make sure the products are in shared memory.
    __syncthreads();

    // The warps finalize the reduction.
    qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
#pragma unroll
    for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
        qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
    }

    // Broadcast to all the threads in the warp.
    qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);
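The two #pragma unroll loops above are butterfly reductions built on __shfl_xor_sync. A minimal standalone sketch of the same pattern, reducing one warp's values to a maximum (illustration only, not part of this commit):

// Warp-wide max via butterfly shuffle: after the loop every lane holds the max.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_max(const float* in, float* out)
{
    float v = in[threadIdx.x];
#pragma unroll
    for (int mask = 16; mask >= 1; mask /= 2) {
        v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, mask));
    }
    if (threadIdx.x == 0) {
        *out = v;  // lane 0 writes the warp-wide maximum
    }
}

int main()
{
    float h_in[32];
    for (int i = 0; i < 32; ++i) {
        h_in[i] = static_cast<float>(i % 7);
    }
    float *d_in, *d_out;
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    warp_max<<<1, 32>>>(d_in, d_out);
    float h_out;
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    printf("warp max = %f\n", h_out);  // expected: 6
    return 0;
}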
    // Compute the logits and start the sum.
    float sum = 0.f;
    // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
    for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
        bool  is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
        float logit   = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max);
        sum += logit;
        qk_smem[ti - first_step] = logit;
    }

    // Compute the sum.
    sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);

    // Normalize the logits.
    float inv_sum = __fdividef(1.f, sum + 1.e-6f);

    // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
    const size_t cross_attention_out_offset =
        params.is_return_cross_attentions ?
            bhi_kv * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len :
            0;
    for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
        float logit = qk_smem[ti - first_step] * inv_sum;
        if (params.is_return_cross_attentions) {
            params.cross_attention_out[cross_attention_out_offset + ti] = logit;
        }
        convert_from_float(logits_smem[ti - first_step], logit);
    }
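Together, the two loops above compute a numerically stable softmax over the stored Q*K^T values: the row maximum is subtracted before exponentiation, and the small epsilon keeps the division well defined even if every timestep is masked,

    p_t = \frac{\exp(qk_t - qk_{\max})}{\sum_j \exp(qk_j - qk_{\max}) + 10^{-6}},

with the normalized probabilities written back into qk_smem / logits_smem for the value pass.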
    // Put Values part below so we leverage __syncthreads
    // from the previous step

    // The number of elements per vector.
    constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
    // A vector of V elements for the current timestep.
    using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;

    // The value computed by this thread.
    int vo = tidx / THREADS_PER_VALUE;
    // The hidden dimensions computed by this particular thread.
    int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;

    // The base pointer for the value in the cache buffer.
    T* v_cache = &params.v_cache[bhi_kv * params.memory_max_len * Dh + vi];
    // Base pointer for the beam's batch, before offsetting with indirection buffer
    T* v_cache_batch = &params.v_cache[bbhi_kv * params.memory_max_len * Dh + vi];

    // The number of values processed per iteration of the loop.
    constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;

    // One group of threads computes the product(s) for the current timestep.
    V_vec v_bias;
    zero(v_bias);
    // if( vo == params.timestep % V_PER_ITER ) {
    if (Dh == Dh_MAX || vi < Dh) {
        if (handle_kv) {
            if (vo == tlength % V_PER_ITER) {
                // Trigger the loads from the V bias buffer.
                if (params.v_bias != nullptr) {
                    v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi_kv * Dh + vi]);
                }
                if (DO_CROSS_ATTENTION) {
                    *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
                }
            }
        }
    }
    // From previous, before values, step
    // Also make sure the logits are in shared memory.
    __syncthreads();

    // Values continued
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
    using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
#else
    using V_vec_acum = V_vec;
#endif
    // The partial outputs computed by each thread.
    V_vec_acum out;
    zero(out);

    // Loop over the timesteps to compute the partial outputs.
    // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
    if (Dh == Dh_MAX || vi < Dh) {
        for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) {
            const int ti_circ = ti % params.memory_max_len;

            // Fetch offset based on cache_indir when beam sampling
            const int beam_src = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
            const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh;
            // Load the values from the cache.
            V_vec v = *reinterpret_cast<const V_vec*>(&v_cache_batch[beam_offset + ti_circ * Dh]);
            if (DO_CROSS_ATTENTION && params.timestep == 0) {
                v = add(v, *reinterpret_cast<V_vec*>(&bias_smem[vi]));
                if (do_ia3) {
                    v = mul<V_vec, V_vec, V_vec>(
                        v,
                        *reinterpret_cast<const V_vec*>(
                            &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
                }
                *reinterpret_cast<V_vec*>(&v_cache[ti * Dh]) = v;
            }
            // Load the logits from shared memory.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
            float logit = logits_smem[ti - first_step];
            out         = fma(logit, cast_to_float(v), out);
#else
            T logit = logits_smem[ti - first_step];

            // Update the partial sums.
            out = fma(logit, v, out);
#endif
        }
    }
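The loop above accumulates the probability-weighted sum of the cached values for the past timesteps,

    \mathrm{out} = \sum_{t=\mathrm{first\_step}}^{\mathrm{tlength}-1} p_t\, V_t,

with the current timestep's contribution p_tlength * v added separately in the block that follows.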
    // One group of threads computes the product(s) for the current timestep.
    // if( vo == params.timestep % V_PER_ITER ) {
    if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {
        V_vec v;
        if (DO_CROSS_ATTENTION) {
            v = *reinterpret_cast<const V_vec*>(&v_cache[tlength * Dh]);
        }
        else {
            // Trigger the loads from the V buffer.
            const auto v_offset = v_base_offset + vi;
            if (params.int8_mode == 2) {
                using Packed_Int8_t  = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
                using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
                const auto v_scaling = params.qkv_scale_out[2];
                const auto v_quant =
                    *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);

                convert_from_float(v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
            }
            else {
                v = *reinterpret_cast<const V_vec*>(&params.v[v_offset]);
            }
            // Trigger the loads from the V bias buffer.
            // V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
        }

        // Compute the V values with bias.
        v = add(v, v_bias);

        if (write_kv_cache) {
            if (do_ia3) {
                v = mul<V_vec, V_vec, V_vec>(
                    v,
                    *reinterpret_cast<const V_vec*>(
                        &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
            }

            // Store the values with bias back to global memory in the cache for V.
            //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
            *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
        }

        // Initialize the output value with the current timestep.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
        // out = fma(logits_smem[params.timestep], cast_to_float(v), out);
        out = fma(logits_smem[tlength - first_step], cast_to_float(v), out);
#else
        // out = fma(logits_smem[params.timestep], v, out);
        out = fma(logits_smem[tlength - first_step], v, out);
#endif
    }
    // Make sure we can start writing to shared memory.
    __syncthreads();

    // Run the final reduction amongst the different groups computing different partial outputs.
    if (Dh == Dh_MAX || vi < Dh) {
#pragma unroll
        for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {

            // The midpoint in the number of active groups.
            int midpoint = active_groups / 2;

            // The upper part of active threads store to shared memory.
            if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
                convert_from_float(*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
#else
                *reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
#endif
            }
            __syncthreads();

            // The bottom warps update their values.
            if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
                out = add(*reinterpret_cast<const V_vec*>(&out_smem[vo * Dh + vi]), out);
            }
            __syncthreads();
        }
    }
    // Output the final values.
    if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
        if (params.int8_mode == 2) {
            using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_acum>::value>::type;
            out                 = mul<V_vec_acum, float>(*params.attention_out_scale, out);
            *reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhi * Dh + vi])) =
                cast_to_int8(out);
        }
        else {
            convert_from_float(*reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]), out);
        }
#else
        // TODO: support int8_mode?
        *reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]) = out;
#endif
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace mmha

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream);
awq_cuda/attention/decoder_masked_multihead_attention_utils.h
0 → 100644
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cuda_bf16_wrapper.h"
#include "cuda_bf16_fallbacks.cuh"
#include <stdint.h>
using namespace fastertransformer;

namespace mmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Float8_ {
    float2 x;
    float2 y;
    float2 z;
    float2 w;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Float4_ {
    float2 x;
    float2 y;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
struct bf16_4_t {
    __nv_bfloat162 x;
    __nv_bfloat162 y;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct bf16_8_t {
    __nv_bfloat162 x;
    __nv_bfloat162 y;
    __nv_bfloat162 z;
    __nv_bfloat162 w;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
struct num_elems;
template<>
struct num_elems<float> {
    static constexpr int value = 1;
};
template<>
struct num_elems<float2> {
    static constexpr int value = 2;
};
template<>
struct num_elems<float4> {
    static constexpr int value = 4;
};
template<>
struct num_elems<Float4_> {
    static constexpr int value = 4;
};
template<>
struct num_elems<Float8_> {
    static constexpr int value = 8;
};

template<>
struct num_elems<uint32_t> {
    static constexpr int value = 2;
};
template<>
struct num_elems<uint2> {
    static constexpr int value = 4;
};
template<>
struct num_elems<uint4> {
    static constexpr int value = 8;
};

#ifdef ENABLE_BF16
template<>
struct num_elems<__nv_bfloat162> {
    static constexpr int value = 2;
};
template<>
struct num_elems<bf16_4_t> {
    static constexpr int value = 4;
};
template<>
struct num_elems<bf16_8_t> {
    static constexpr int value = 8;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T, int N>
struct packed_type;
template<typename T>
struct packed_type<T, 1> {
    using type = T;
};
template<>
struct packed_type<int8_t, 2> {
    using type = int16_t;
};
template<>
struct packed_type<int8_t, 4> {
    using type = int32_t;
};
template<>
struct packed_type<int8_t, 8> {
    using type = int64_t;
};

template<>
struct packed_type<float, 2> {
    using type = float2;
};
template<>
struct packed_type<float, 4> {
    using type = float4;
};
template<>
struct packed_type<float, 8> {
    using type = Float8_;
};
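As used throughout the kernels, num_elems and packed_type compose: num_elems reports how many scalar lanes a vector type carries, and packed_type answers with a type that packs that many lanes of another element type. A small test snippet one could drop into a separate translation unit (assumes the definitions above are in scope; illustration only, not part of this commit):

// Compile-time checks of the trait composition used by the int8 load paths.
#include <cstdint>
#include <cuda_runtime.h>
#include <type_traits>

static_assert(num_elems<float4>::value == 4, "float4 carries four lanes");
static_assert(std::is_same<packed_type<int8_t, num_elems<float4>::value>::type, int32_t>::value,
              "four int8 lanes pack into one int32_t word");
static_assert(std::is_same<packed_type<float, 2>::type, float2>::value,
              "two float lanes pack into a float2");
static_assert(std::is_same<packed_type<float, 8>::type, Float8_>::value,
              "eight float lanes pack into the Float8_ helper struct");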
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float add(float a, float b)
{
    return a + b;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 add(float2 a, float2 b)
{
    float2 c;
    c.x = add(a.x, b.x);
    c.y = add(a.y, b.y);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float4 add(float4 a, float4 b)
{
    float4 c;
    c.x = add(a.x, b.x);
    c.y = add(a.y, b.y);
    c.z = add(a.z, b.z);
    c.w = add(a.w, b.w);
    return c;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
__nv_bfloat16
add
(
__nv_bfloat16
a
,
__nv_bfloat16
b
)
{
return
a
+
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
__nv_bfloat162
add
(
__nv_bfloat162
a
,
__nv_bfloat162
b
)
{
return
bf16hadd2
(
a
,
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_4_t
add
(
bf16_4_t
a
,
bf16_4_t
b
)
{
bf16_4_t
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_8_t
add
(
bf16_8_t
a
,
bf16_8_t
b
)
{
bf16_8_t
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
c
.
z
=
add
(
a
.
z
,
b
.
z
);
c
.
w
=
add
(
a
.
w
,
b
.
w
);
return
c
;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint16_t
add
(
uint16_t
a
,
uint16_t
b
)
{
uint16_t
c
;
asm
volatile
(
"add.f16 %0, %1, %2;
\n
"
:
"=h"
(
c
)
:
"h"
(
a
),
"h"
(
b
));
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
add
(
uint32_t
a
,
uint32_t
b
)
{
uint32_t
c
;
asm
volatile
(
"add.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
c
)
:
"r"
(
a
),
"r"
(
b
));
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint2
add
(
uint2
a
,
uint2
b
)
{
uint2
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint4
add
(
uint4
a
,
uint4
b
)
{
uint4
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
c
.
z
=
add
(
a
.
z
,
b
.
z
);
c
.
w
=
add
(
a
.
w
,
b
.
w
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint16_t
float_to_half
(
float
f
)
{
union
{
uint32_t
u32
;
uint16_t
u16
[
2
];
}
tmp
;
#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better?
float zero = 0.f;
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f));
#else
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;
\n
"
:
"=h"
(
tmp
.
u16
[
0
])
:
"f"
(
f
));
#endif
return
tmp
.
u16
[
0
];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
float2_to_half2
(
float2
f
)
{
union
{
uint32_t
u32
;
uint16_t
u16
[
2
];
}
tmp
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
"cvt.rn.f16x2.f32 %0, %1, %2;
\n
"
:
"=r"
(
tmp
.
u32
)
:
"f"
(
f
.
y
),
"f"
(
f
.
x
));
#else
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;
\n
"
:
"=h"
(
tmp
.
u16
[
0
])
:
"f"
(
f
.
x
));
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;
\n
"
:
"=h"
(
tmp
.
u16
[
1
])
:
"f"
(
f
.
y
));
#endif
return
tmp
.
u32
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
half_to_float
(
uint16_t
h
)
{
float
f
;
asm
volatile
(
"cvt.f32.f16 %0, %1;
\n
"
:
"=f"
(
f
)
:
"h"
(
h
));
return
f
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
half2_to_float2
(
uint32_t
v
)
{
uint16_t
lo
,
hi
;
asm
volatile
(
"mov.b32 {%0, %1}, %2;
\n
"
:
"=h"
(
lo
),
"=h"
(
hi
)
:
"r"
(
v
));
return
make_float2
(
half_to_float
(
lo
),
half_to_float
(
hi
));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
add
(
float
a
,
uint16_t
b
)
{
return
a
+
half_to_float
(
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
float
add
(
float
a
,
__nv_bfloat16
b
)
{
return
a
+
__bfloat162float
(
b
);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
add
(
uint32_t
a
,
float2
fb
)
{
float2
fa
=
half2_to_float2
(
a
);
return
add
(
fa
,
fb
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
add
(
uint2
a
,
Float4_
fb
)
{
Float4_
fc
;
fc
.
x
=
add
(
a
.
x
,
fb
.
x
);
fc
.
y
=
add
(
a
.
y
,
fb
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
add
(
uint4
a
,
Float8_
fb
)
{
Float8_
fc
;
fc
.
x
=
add
(
a
.
x
,
fb
.
x
);
fc
.
y
=
add
(
a
.
y
,
fb
.
y
);
fc
.
z
=
add
(
a
.
z
,
fb
.
z
);
fc
.
w
=
add
(
a
.
w
,
fb
.
w
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
h0_h0
(
uint16_t
a
)
{
uint32_t
b
;
asm
volatile
(
"mov.b32 %0, {%1, %1};"
:
"=r"
(
b
)
:
"h"
(
a
));
return
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
fma
(
float
a
,
float
b
,
float
c
)
{
return
a
*
b
+
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
float2
a
,
float2
b
,
float2
c
)
{
float2
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
float
a
,
float2
b
,
float2
c
)
{
float2
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
fma
(
float4
a
,
float4
b
,
float4
c
)
{
float4
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
.
z
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
.
w
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
fma
(
float
a
,
float4
b
,
float4
c
)
{
float4
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
float
a
,
Float4_
b
,
Float4_
c
)
{
Float4_
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
float
a
,
Float8_
b
,
Float8_
c
)
{
Float8_
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
float2
add
(
__nv_bfloat162
a
,
float2
fb
)
{
float2
fa
=
bf1622float2
(
a
);
return
add
(
fa
,
fb
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
add
(
bf16_4_t
a
,
Float4_
fb
)
{
Float4_
fc
;
fc
.
x
=
add
(
a
.
x
,
fb
.
x
);
fc
.
y
=
add
(
a
.
y
,
fb
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
add
(
bf16_8_t
a
,
Float8_
fb
)
{
Float8_
fc
;
fc
.
x
=
add
(
a
.
x
,
fb
.
x
);
fc
.
y
=
add
(
a
.
y
,
fb
.
y
);
fc
.
z
=
add
(
a
.
z
,
fb
.
z
);
fc
.
w
=
add
(
a
.
w
,
fb
.
w
);
return
fc
;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
fma
(
uint32_t
a
,
uint32_t
b
,
uint32_t
c
)
{
uint32_t
d
;
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
d
)
:
"r"
(
a
),
"r"
(
b
),
"r"
(
c
));
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
fma
(
uint16_t
a
,
uint32_t
b
,
uint32_t
c
)
{
return
fma
(
h0_h0
(
a
),
b
,
c
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint2
fma
(
uint2
a
,
uint2
b
,
uint2
c
)
{
uint2
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint2
fma
(
uint16_t
a
,
uint2
b
,
uint2
c
)
{
uint32_t
s
=
h0_h0
(
a
);
uint2
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint4
fma
(
uint4
a
,
uint4
b
,
uint4
c
)
{
uint4
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
.
z
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
.
w
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint4
fma
(
uint16_t
a
,
uint4
b
,
uint4
c
)
{
uint32_t
s
=
h0_h0
(
a
);
uint4
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
s
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
s
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
fma
(
uint16_t
a
,
uint16_t
b
,
float
fc
)
{
float
fa
=
half_to_float
(
a
);
float
fb
=
half_to_float
(
b
);
return
fa
*
fb
+
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
uint32_t
a
,
uint32_t
b
,
float2
fc
)
{
float2
fa
=
half2_to_float2
(
a
);
float2
fb
=
half2_to_float2
(
b
);
return
fma
(
fa
,
fb
,
fc
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
uint16_t
a
,
uint32_t
b
,
float2
fc
)
{
return
fma
(
h0_h0
(
a
),
b
,
fc
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
uint2
a
,
uint2
b
,
Float4_
fc
)
{
Float4_
fd
;
fd
.
x
=
fma
(
a
.
x
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
a
.
y
,
b
.
y
,
fc
.
y
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
uint16_t
a
,
uint2
b
,
Float4_
fc
)
{
uint32_t
s
=
h0_h0
(
a
);
Float4_
fd
;
fd
.
x
=
fma
(
s
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
s
,
b
.
y
,
fc
.
y
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
uint4
a
,
uint4
b
,
Float8_
fc
)
{
Float8_
fd
;
fd
.
x
=
fma
(
a
.
x
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
a
.
y
,
b
.
y
,
fc
.
y
);
fd
.
z
=
fma
(
a
.
z
,
b
.
z
,
fc
.
z
);
fd
.
w
=
fma
(
a
.
w
,
b
.
w
,
fc
.
w
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
uint16_t
a
,
uint4
b
,
Float8_
fc
)
{
uint32_t
s
=
h0_h0
(
a
);
Float8_
fd
;
fd
.
x
=
fma
(
s
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
s
,
b
.
y
,
fc
.
y
);
fd
.
z
=
fma
(
s
,
b
.
z
,
fc
.
z
);
fd
.
w
=
fma
(
s
,
b
.
w
,
fc
.
w
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
__nv_bfloat162
fma
(
__nv_bfloat162
a
,
__nv_bfloat162
b
,
__nv_bfloat162
c
)
{
return
bf16hfma2
(
a
,
b
,
c
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
__nv_bfloat162
fma
(
__nv_bfloat16
a
,
__nv_bfloat162
b
,
__nv_bfloat162
c
)
{
return
bf16hfma2
(
bf162bf162
(
a
),
b
,
c
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_4_t
fma
(
bf16_4_t
a
,
bf16_4_t
b
,
bf16_4_t
c
)
{
bf16_4_t
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_4_t
fma
(
__nv_bfloat16
a
,
bf16_4_t
b
,
bf16_4_t
c
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
bf16_4_t
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_8_t
fma
(
bf16_8_t
a
,
bf16_8_t
b
,
bf16_8_t
c
)
{
bf16_8_t
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
.
z
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
.
w
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_8_t
fma
(
__nv_bfloat16
a
,
bf16_8_t
b
,
bf16_8_t
c
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
bf16_8_t
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
s
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
s
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
fma
(
__nv_bfloat16
a
,
__nv_bfloat16
b
,
float
fc
)
{
return
__bfloat162float
(
a
)
*
__bfloat162float
(
b
)
+
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
__nv_bfloat162
a
,
__nv_bfloat162
b
,
float2
fc
)
{
float2
fa
=
bf1622float2
(
a
);
float2
fb
=
bf1622float2
(
b
);
return
fma
(
fa
,
fb
,
fc
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
__nv_bfloat16
a
,
__nv_bfloat162
b
,
float2
fc
)
{
return
fma
(
bf162bf162
(
a
),
b
,
fc
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
bf16_4_t
a
,
bf16_4_t
b
,
Float4_
fc
)
{
Float4_
fd
;
fd
.
x
=
fma
(
a
.
x
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
a
.
y
,
b
.
y
,
fc
.
y
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
__nv_bfloat16
a
,
bf16_4_t
b
,
Float4_
fc
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
Float4_
fd
;
fd
.
x
=
fma
(
s
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
s
,
b
.
y
,
fc
.
y
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
bf16_8_t
a
,
bf16_8_t
b
,
Float8_
fc
)
{
Float8_
fd
;
fd
.
x
=
fma
(
a
.
x
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
a
.
y
,
b
.
y
,
fc
.
y
);
fd
.
z
=
fma
(
a
.
z
,
b
.
z
,
fc
.
z
);
fd
.
w
=
fma
(
a
.
w
,
b
.
w
,
fc
.
w
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
__nv_bfloat16
a
,
bf16_8_t
b
,
Float8_
fc
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
Float8_
fd
;
fd
.
x
=
fma
(
s
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
s
,
b
.
y
,
fc
.
y
);
fd
.
z
=
fma
(
s
,
b
.
z
,
fc
.
z
);
fd
.
w
=
fma
(
s
,
b
.
w
,
fc
.
w
);
return
fd
;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Acc
,
typename
A
,
typename
B
>
inline
__device__
Acc
mul
(
A
a
,
B
b
)
{
return
a
*
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float
mul
<
float
,
float
>
(
float
a
,
float
b
)
{
return
a
*
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
float2
a
,
float2
b
)
{
float2
c
;
c
.
x
=
a
.
x
*
b
.
x
;
c
.
y
=
a
.
y
*
b
.
y
;
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
float
a
,
float2
b
)
{
float2
c
;
c
.
x
=
a
*
b
.
x
;
c
.
y
=
a
*
b
.
y
;
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float4
mul
(
float4
a
,
float4
b
)
{
float4
c
;
c
.
x
=
a
.
x
*
b
.
x
;
c
.
y
=
a
.
y
*
b
.
y
;
c
.
z
=
a
.
z
*
b
.
z
;
c
.
w
=
a
.
w
*
b
.
w
;
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float4
mul
(
float
a
,
float4
b
)
{
float4
c
;
c
.
x
=
a
*
b
.
x
;
c
.
y
=
a
*
b
.
y
;
c
.
z
=
a
*
b
.
z
;
c
.
w
=
a
*
b
.
w
;
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float8_
mul
(
float
a
,
Float8_
b
)
{
Float8_
c
;
c
.
x
=
make_float2
(
a
*
b
.
x
.
x
,
a
*
b
.
x
.
y
);
c
.
y
=
make_float2
(
a
*
b
.
y
.
x
,
a
*
b
.
y
.
y
);
c
.
z
=
make_float2
(
a
*
b
.
z
.
x
,
a
*
b
.
z
.
y
);
c
.
w
=
make_float2
(
a
*
b
.
w
.
x
,
a
*
b
.
w
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint16_t
mul
(
uint16_t
a
,
uint16_t
b
)
{
uint16_t
c
;
asm
volatile
(
"mul.f16 %0, %1, %2;
\n
"
:
"=h"
(
c
)
:
"h"
(
a
),
"h"
(
b
));
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint32_t
mul
(
uint32_t
a
,
uint32_t
b
)
{
uint32_t
c
;
asm
volatile
(
"mul.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
c
)
:
"r"
(
a
),
"r"
(
b
));
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint32_t
mul
(
uint16_t
a
,
uint32_t
b
)
{
return
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
h0_h0
(
a
),
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint2
mul
(
uint2
a
,
uint2
b
)
{
uint2
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint2
mul
(
uint16_t
a
,
uint2
b
)
{
uint32_t
s
=
h0_h0
(
a
);
uint2
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint4
mul
(
uint4
a
,
uint4
b
)
{
uint4
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
c
.
z
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
z
,
b
.
z
);
c
.
w
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
w
,
b
.
w
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint4
mul
(
uint16_t
a
,
uint4
b
)
{
uint32_t
s
=
h0_h0
(
a
);
uint4
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
y
);
c
.
z
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
z
);
c
.
w
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
w
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float
mul
(
uint16_t
a
,
uint16_t
b
)
{
float
fa
=
half_to_float
(
a
);
float
fb
=
half_to_float
(
b
);
return
fa
*
fb
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float
mul
(
uint16_t
a
,
float
b
)
{
return
half_to_float
(
a
)
*
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
uint32_t
a
,
uint32_t
b
)
{
float2
fa
=
half2_to_float2
(
a
);
float2
fb
=
half2_to_float2
(
b
);
return
mul
<
float2
,
float2
,
float2
>
(
fa
,
fb
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
uint16_t
a
,
uint32_t
b
)
{
return
mul
<
float2
,
uint32_t
,
uint32_t
>
(
h0_h0
(
a
),
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float4_
mul
(
uint2
a
,
uint2
b
)
{
Float4_
fc
;
fc
.
x
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float4_
mul
(
uint16_t
a
,
uint2
b
)
{
uint32_t
s
=
h0_h0
(
a
);
Float4_
fc
;
fc
.
x
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float8_
mul
(
uint4
a
,
uint4
b
)
{
Float8_
fc
;
fc
.
x
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
fc
.
z
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
z
,
b
.
z
);
fc
.
w
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
w
,
b
.
w
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float8_
mul
(
uint16_t
a
,
uint4
b
)
{
uint32_t
s
=
h0_h0
(
a
);
Float8_
fc
;
fc
.
x
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
y
);
fc
.
z
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
z
);
fc
.
w
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
w
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
template
<
>
inline
__device__
__nv_bfloat16
mul
(
__nv_bfloat16
a
,
__nv_bfloat16
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return
__hmul
(
a
,
b
);
#else
return
bf16hmul
(
a
,
b
);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
__nv_bfloat162
mul
(
__nv_bfloat162
a
,
__nv_bfloat162
b
)
{
return
bf16hmul2
(
a
,
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
__nv_bfloat162
mul
(
__nv_bfloat16
a
,
__nv_bfloat162
b
)
{
return
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
bf162bf162
(
a
),
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
bf16_4_t
mul
(
bf16_4_t
a
,
bf16_4_t
b
)
{
bf16_4_t
c
;
c
.
x
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
y
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
bf16_4_t
mul
(
__nv_bfloat16
a
,
bf16_4_t
b
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
bf16_4_t
c
;
c
.
x
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
x
);
c
.
y
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
bf16_8_t
mul
(
bf16_8_t
a
,
bf16_8_t
b
)
{
bf16_8_t
c
;
c
.
x
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
y
,
b
.
y
);
c
.
z
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
z
,
b
.
z
);
c
.
w
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
w
,
b
.
w
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
bf16_8_t
mul
(
__nv_bfloat16
a
,
bf16_8_t
b
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
bf16_8_t
c
;
c
.
x
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
x
);
c
.
y
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
y
);
c
.
z
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
z
);
c
.
w
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
w
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float
mul
(
__nv_bfloat16
a
,
__nv_bfloat16
b
)
{
float
fa
=
(
float
)
a
;
float
fb
=
(
float
)
b
;
return
fa
*
fb
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float
mul
(
__nv_bfloat16
a
,
float
b
)
{
return
__bfloat162float
(
a
)
*
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
__nv_bfloat162
a
,
__nv_bfloat162
b
)
{
float2
fa
=
bf1622float2
(
a
);
float2
fb
=
bf1622float2
(
b
);
return
mul
<
float2
,
float2
,
float2
>
(
fa
,
fb
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
__nv_bfloat16
a
,
__nv_bfloat162
b
)
{
return
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
bf162bf162
(
a
),
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float4_
mul
(
bf16_4_t
a
,
bf16_4_t
b
)
{
Float4_
fc
;
fc
.
x
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
x
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
y
,
b
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float4_
mul
(
__nv_bfloat16
a
,
bf16_4_t
b
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
Float4_
fc
;
fc
.
x
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float8_
mul
(
bf16_8_t
a
,
bf16_8_t
b
)
{
Float8_
fc
;
fc
.
x
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
x
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
y
,
b
.
y
);
fc
.
z
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
z
,
b
.
z
);
fc
.
w
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
w
,
b
.
w
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float8_
mul
(
__nv_bfloat16
a
,
bf16_8_t
b
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
Float8_
fc
;
fc
.
x
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
y
);
fc
.
z
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
z
);
fc
.
w
=
mul
<
float2
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
w
);
return
fc
;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float sum(float v)
{
    return v;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(float2 v)
{
    return v.x + v.y;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float sum(float4 v)
{
    return v.x + v.y + v.z + v.w;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
float
sum
(
__nv_bfloat162
v
)
{
float2
vf
=
bf1622float2
(
v
);
return
vf
.
x
+
vf
.
y
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
bf16_4_t
v
)
{
return
sum
(
v
.
x
)
+
sum
(
v
.
y
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
bf16_8_t
v
)
{
return
sum
(
v
.
x
)
+
sum
(
v
.
y
)
+
sum
(
v
.
z
)
+
sum
(
v
.
w
);
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
uint16_t
v
)
{
return
half_to_float
(
v
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
uint32_t
v
)
{
float2
tmp
=
half2_to_float2
(
v
);
return
tmp
.
x
+
tmp
.
y
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
uint2
v
)
{
uint32_t
c
=
add
(
v
.
x
,
v
.
y
);
return
sum
(
c
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
uint4
v
)
{
#if 1
uint32_t
c
=
add
(
v
.
x
,
v
.
y
);
c
=
add
(
c
,
v
.
z
);
c
=
add
(
c
,
v
.
w
);
#else
uint32_t
c
=
add
(
v
.
x
,
v
.
y
);
uint32_t
d
=
add
(
v
.
z
,
v
.
w
);
c
=
add
(
c
,
d
);
#endif
return
sum
(
c
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
Float4_
v
)
{
return
v
.
x
.
x
+
v
.
x
.
y
+
v
.
y
.
x
+
v
.
y
.
y
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
sum
(
Float8_
v
)
{
return
v
.
x
.
x
+
v
.
x
.
y
+
v
.
y
.
x
+
v
.
y
.
y
+
v
.
z
.
x
+
v
.
z
.
y
+
v
.
w
.
x
+
v
.
w
.
y
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
inline __device__ float dot(T a, T b)
{
    return sum(mul<T, T, T>(a, b));
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename A, typename T>
inline __device__ float dot(T a, T b)
{
    return sum(mul<A, T, T>(a, b));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ void zero(uint16_t& dst)
{
    dst = uint16_t(0);
}

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
inline __device__ void zero(T& dst)
{
    constexpr int WORDS = sizeof(T) / 4;
    union {
        T        raw;
        uint32_t words[WORDS];
    } tmp;
#pragma unroll
    for (int ii = 0; ii < WORDS; ++ii) {
        tmp.words[ii] = 0u;
    }
    dst = tmp.raw;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base)
{
    const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
    return {cos(inv_freq), sin(inv_freq)};
}

inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef)
{
    float2 rot_v;
    rot_v.x = coef.x * v.x - coef.y * v.y;
    rot_v.y = coef.x * v.y + coef.y * v.x;
    return rot_v;
}

inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef)
{
    float2 fv     = half2_to_float2(v);
    float2 rot_fv = rotary_embedding_transform(fv, coef);
    return float2_to_half2(rot_fv);
}

#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef)
{
    float2 fv     = bf1622float2(v);
    float2 rot_fv = rotary_embedding_transform(fv, coef);
    return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
}
#endif
inline
__device__
void
apply_rotary_embedding
(
float
&
q
,
int
zid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
return
;
}
inline
__device__
void
apply_rotary_embedding
(
float
&
q
,
float
&
k
,
int
zid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
return
;
}
inline
__device__
void
apply_rotary_embedding
(
float2
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
=
rotary_embedding_transform
(
q
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
float2
&
q
,
float2
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
=
rotary_embedding_transform
(
q
,
coef
);
k
=
rotary_embedding_transform
(
k
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
float4
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
Float4_
&
q_
=
*
reinterpret_cast
<
Float4_
*>
(
&
q
);
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q_
.
x
=
rotary_embedding_transform
(
q_
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q_
.
y
=
rotary_embedding_transform
(
q_
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
float4
&
q
,
float4
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
Float4_
&
q_
=
*
reinterpret_cast
<
Float4_
*>
(
&
q
);
Float4_
&
k_
=
*
reinterpret_cast
<
Float4_
*>
(
&
k
);
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q_
.
x
=
rotary_embedding_transform
(
q_
.
x
,
coef0
);
k_
.
x
=
rotary_embedding_transform
(
k_
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q_
.
y
=
rotary_embedding_transform
(
q_
.
y
,
coef1
);
k_
.
y
=
rotary_embedding_transform
(
k_
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
uint32_t
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
=
rotary_embedding_transform
(
q
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
uint32_t
&
q
,
uint32_t
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
=
rotary_embedding_transform
(
q
,
coef
);
k
=
rotary_embedding_transform
(
k
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
uint2
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
uint2
&
q
,
uint2
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
uint4
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
rot_embed_dim
,
t_step
,
base
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
rot_embed_dim
,
t_step
,
base
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
}
inline
__device__
void
apply_rotary_embedding
(
uint4
&
q
,
uint4
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
rot_embed_dim
,
t_step
,
base
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
k
.
z
=
rotary_embedding_transform
(
k
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
rot_embed_dim
,
t_step
,
base
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
k
.
w
=
rotary_embedding_transform
(
k
.
w
,
coef3
);
}
#ifdef ENABLE_BF16
inline
__device__
void
apply_rotary_embedding
(
__nv_bfloat162
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
=
rotary_embedding_transform
(
q
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
__nv_bfloat162
&
q
,
__nv_bfloat162
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
2
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef
=
rotary_embedding_coefficient
(
2
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
=
rotary_embedding_transform
(
q
,
coef
);
k
=
rotary_embedding_transform
(
k
,
coef
);
}
inline
__device__
void
apply_rotary_embedding
(
bf16_4_t
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
bf16_4_t
&
q
,
bf16_4_t
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
4
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
4
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
4
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
}
inline
__device__
void
apply_rotary_embedding
(
bf16_8_t
&
q
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
rot_embed_dim
,
t_step
,
base
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
rot_embed_dim
,
t_step
,
base
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
}
inline
__device__
void
apply_rotary_embedding
(
bf16_8_t
&
q
,
bf16_8_t
&
k
,
int
tid
,
int
rot_embed_dim
,
int
t_step
,
const
float
base
=
10000.0
f
)
{
if
(
8
*
tid
>=
rot_embed_dim
)
{
return
;
}
const
auto
coef0
=
rotary_embedding_coefficient
(
8
*
tid
,
rot_embed_dim
,
t_step
,
base
);
q
.
x
=
rotary_embedding_transform
(
q
.
x
,
coef0
);
k
.
x
=
rotary_embedding_transform
(
k
.
x
,
coef0
);
const
auto
coef1
=
rotary_embedding_coefficient
(
8
*
tid
+
2
,
rot_embed_dim
,
t_step
,
base
);
q
.
y
=
rotary_embedding_transform
(
q
.
y
,
coef1
);
k
.
y
=
rotary_embedding_transform
(
k
.
y
,
coef1
);
const
auto
coef2
=
rotary_embedding_coefficient
(
8
*
tid
+
4
,
rot_embed_dim
,
t_step
,
base
);
q
.
z
=
rotary_embedding_transform
(
q
.
z
,
coef2
);
k
.
z
=
rotary_embedding_transform
(
k
.
z
,
coef2
);
const
auto
coef3
=
rotary_embedding_coefficient
(
8
*
tid
+
6
,
rot_embed_dim
,
t_step
,
base
);
q
.
w
=
rotary_embedding_transform
(
q
.
w
,
coef3
);
k
.
w
=
rotary_embedding_transform
(
k
.
w
,
coef3
);
}
#endif // ENABLE_BF16
template
<
typename
Vec_T
,
typename
T
>
__device__
__inline__
void
vec_from_smem_transpose
(
Vec_T
&
vec
,
T
*
smem
,
int
transpose_idx
,
int
smem_pitch
);
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
float
&
vec
,
float
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
return
;
}
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
uint32_t
&
vec
,
uint16_t
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
union
{
uint32_t
u32
;
uint16_t
u16
[
2
];
}
tmp
;
tmp
.
u16
[
0
]
=
smem
[
transpose_idx
];
tmp
.
u16
[
1
]
=
smem
[
smem_pitch
+
transpose_idx
];
vec
=
tmp
.
u32
;
}
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
uint2
&
vec
,
uint16_t
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
union
{
uint32_t
u32
;
uint16_t
u16
[
2
];
}
tmp_1
,
tmp_2
;
tmp_1
.
u32
=
*
reinterpret_cast
<
uint32_t
*>
(
&
smem
[
transpose_idx
]);
tmp_2
.
u32
=
*
reinterpret_cast
<
uint32_t
*>
(
&
smem
[
smem_pitch
+
transpose_idx
]);
union
{
uint2
u32x2
;
uint16_t
u16
[
4
];
}
tmp_3
;
tmp_3
.
u16
[
0
]
=
tmp_1
.
u16
[
0
];
tmp_3
.
u16
[
1
]
=
tmp_2
.
u16
[
0
];
tmp_3
.
u16
[
2
]
=
tmp_1
.
u16
[
1
];
tmp_3
.
u16
[
3
]
=
tmp_2
.
u16
[
1
];
vec
=
tmp_3
.
u32x2
;
}
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
uint4
&
vec
,
uint16_t
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
union
{
uint64_t
u64
;
uint16_t
u16
[
4
];
}
tmp_1
,
tmp_2
;
tmp_1
.
u64
=
*
reinterpret_cast
<
uint64_t
*>
(
&
smem
[
transpose_idx
]);
tmp_2
.
u64
=
*
reinterpret_cast
<
uint64_t
*>
(
&
smem
[
smem_pitch
+
transpose_idx
]);
union
{
uint4
u32x4
;
uint16_t
u16
[
8
];
}
tmp_3
;
tmp_3
.
u16
[
0
]
=
tmp_1
.
u16
[
0
];
tmp_3
.
u16
[
1
]
=
tmp_2
.
u16
[
0
];
tmp_3
.
u16
[
2
]
=
tmp_1
.
u16
[
1
];
tmp_3
.
u16
[
3
]
=
tmp_2
.
u16
[
1
];
tmp_3
.
u16
[
4
]
=
tmp_1
.
u16
[
2
];
tmp_3
.
u16
[
5
]
=
tmp_2
.
u16
[
2
];
tmp_3
.
u16
[
6
]
=
tmp_1
.
u16
[
3
];
tmp_3
.
u16
[
7
]
=
tmp_2
.
u16
[
3
];
vec
=
tmp_3
.
u32x4
;
}
#ifdef ENABLE_BF16
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
bf16_4_t
&
vec
,
__nv_bfloat16
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
union
{
uint32_t
u32
;
__nv_bfloat16
bf16
[
2
];
}
tmp_1
,
tmp_2
;
tmp_1
.
u32
=
*
reinterpret_cast
<
uint32_t
*>
(
&
smem
[
transpose_idx
]);
tmp_2
.
u32
=
*
reinterpret_cast
<
uint32_t
*>
(
&
smem
[
smem_pitch
+
transpose_idx
]);
vec
.
x
=
__nv_bfloat162
{
tmp_1
.
bf16
[
0
],
tmp_2
.
bf16
[
0
]};
vec
.
y
=
__nv_bfloat162
{
tmp_1
.
bf16
[
1
],
tmp_2
.
bf16
[
1
]};
}
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
bf16_8_t
&
vec
,
__nv_bfloat16
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
union
{
uint64_t
u64
;
__nv_bfloat16
bf16
[
4
];
}
tmp_1
,
tmp_2
;
tmp_1
.
u64
=
*
reinterpret_cast
<
uint64_t
*>
(
&
smem
[
transpose_idx
]);
tmp_2
.
u64
=
*
reinterpret_cast
<
uint64_t
*>
(
&
smem
[
smem_pitch
+
transpose_idx
]);
vec
.
x
=
__nv_bfloat162
{
tmp_1
.
bf16
[
0
],
tmp_2
.
bf16
[
0
]};
vec
.
y
=
__nv_bfloat162
{
tmp_1
.
bf16
[
1
],
tmp_2
.
bf16
[
1
]};
vec
.
z
=
__nv_bfloat162
{
tmp_1
.
bf16
[
2
],
tmp_2
.
bf16
[
2
]};
vec
.
w
=
__nv_bfloat162
{
tmp_1
.
bf16
[
3
],
tmp_2
.
bf16
[
3
]};
}
#endif // ENABLE_BF16
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
float4
&
vec
,
float
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
vec
.
x
=
smem
[
transpose_idx
];
vec
.
z
=
smem
[
transpose_idx
+
1
];
vec
.
y
=
smem
[
smem_pitch
+
transpose_idx
];
vec
.
w
=
smem
[
smem_pitch
+
transpose_idx
+
1
];
}
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
uint32_t
&
vec
,
half
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
union
{
uint32_t
u32
;
half
u16
[
2
];
}
tmp
;
tmp
.
u16
[
0
]
=
smem
[
transpose_idx
];
tmp
.
u16
[
1
]
=
smem
[
smem_pitch
+
transpose_idx
];
vec
=
tmp
.
u32
;
}
#ifdef ENABLE_BF16
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
__nv_bfloat162
&
vec
,
__nv_bfloat16
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
vec
.
x
=
smem
[
transpose_idx
];
vec
.
y
=
smem
[
smem_pitch
+
transpose_idx
];
}
#endif
template
<
>
__device__
__inline__
void
vec_from_smem_transpose
(
float2
&
vec
,
float
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
vec
.
x
=
smem
[
transpose_idx
];
vec
.
y
=
smem
[
smem_pitch
+
transpose_idx
];
}
template
<
typename
Vec_T
,
typename
T
>
__device__
__inline__
void
write_smem_transpose
(
const
Vec_T
&
vec
,
T
*
smem
,
int
transpose_idx
,
int
smem_pitch
);
template
<
>
__device__
__inline__
void
write_smem_transpose
(
const
float
&
vec
,
float
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
return
;
}
template
<
>
__device__
__inline__
void
write_smem_transpose
(
const
uint4
&
vec
,
uint16_t
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
union
{
uint64_t
u64
;
uint16_t
u16
[
4
];
}
tmp_1
,
tmp_2
;
union
{
uint4
u32x4
;
uint16_t
u16
[
8
];
}
tmp_3
;
tmp_3
.
u32x4
=
vec
;
tmp_1
.
u16
[
0
]
=
tmp_3
.
u16
[
0
];
tmp_2
.
u16
[
0
]
=
tmp_3
.
u16
[
1
];
tmp_1
.
u16
[
1
]
=
tmp_3
.
u16
[
2
];
tmp_2
.
u16
[
1
]
=
tmp_3
.
u16
[
3
];
tmp_1
.
u16
[
2
]
=
tmp_3
.
u16
[
4
];
tmp_2
.
u16
[
2
]
=
tmp_3
.
u16
[
5
];
tmp_1
.
u16
[
3
]
=
tmp_3
.
u16
[
6
];
tmp_2
.
u16
[
3
]
=
tmp_3
.
u16
[
7
];
*
reinterpret_cast
<
uint64_t
*>
(
&
smem
[
transpose_idx
])
=
tmp_1
.
u64
;
*
reinterpret_cast
<
uint64_t
*>
(
&
smem
[
smem_pitch
+
transpose_idx
])
=
tmp_2
.
u64
;
}
template
<
>
__device__
__inline__
void
write_smem_transpose
(
const
uint2
&
vec
,
uint16_t
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
union
{
uint32_t
u32
;
uint16_t
u16
[
2
];
}
tmp_1
,
tmp_2
;
union
{
uint2
u32x2
;
uint16_t
u16
[
4
];
}
tmp_3
;
tmp_3
.
u32x2
=
vec
;
tmp_1
.
u16
[
0
]
=
tmp_3
.
u16
[
0
];
tmp_2
.
u16
[
0
]
=
tmp_3
.
u16
[
1
];
tmp_1
.
u16
[
1
]
=
tmp_3
.
u16
[
2
];
tmp_2
.
u16
[
1
]
=
tmp_3
.
u16
[
3
];
*
reinterpret_cast
<
uint32_t
*>
(
&
smem
[
transpose_idx
])
=
tmp_1
.
u32
;
*
reinterpret_cast
<
uint32_t
*>
(
&
smem
[
smem_pitch
+
transpose_idx
])
=
tmp_2
.
u32
;
}
template
<
>
__device__
__inline__
void
write_smem_transpose
(
const
uint32_t
&
vec
,
uint16_t
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
union
{
uint32_t
u32
;
uint16_t
u16
[
2
];
}
tmp
;
tmp
.
u32
=
vec
;
smem
[
transpose_idx
]
=
tmp
.
u16
[
0
];
smem
[
smem_pitch
+
transpose_idx
]
=
tmp
.
u16
[
1
];
}
template
<
>
__device__
__inline__
void
write_smem_transpose
(
const
float4
&
vec
,
float
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
smem
[
transpose_idx
]
=
vec
.
x
;
smem
[
transpose_idx
+
1
]
=
vec
.
z
;
smem
[
smem_pitch
+
transpose_idx
]
=
vec
.
y
;
smem
[
smem_pitch
+
transpose_idx
+
1
]
=
vec
.
w
;
}
template
<
>
__device__
__inline__
void
write_smem_transpose
(
const
uint32_t
&
vec
,
half
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
union
{
uint32_t
u32
;
half
u16
[
2
];
}
tmp
;
tmp
.
u32
=
vec
;
smem
[
transpose_idx
]
=
tmp
.
u16
[
0
];
smem
[
smem_pitch
+
transpose_idx
]
=
tmp
.
u16
[
1
];
}
#ifdef ENABLE_BF16
template
<
>
__device__
__inline__
void
write_smem_transpose
(
const
__nv_bfloat162
&
vec
,
__nv_bfloat16
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
smem
[
transpose_idx
]
=
vec
.
x
;
smem
[
smem_pitch
+
transpose_idx
]
=
vec
.
y
;
}
template
<
>
__device__
__inline__
void
write_smem_transpose
(
const
bf16_4_t
&
vec
,
__nv_bfloat16
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
write_smem_transpose
(
reinterpret_cast
<
const
uint2
&>
(
vec
),
reinterpret_cast
<
uint16_t
*>
(
smem
),
transpose_idx
,
smem_pitch
);
}
template
<
>
__device__
__inline__
void
write_smem_transpose
(
const
bf16_8_t
&
vec
,
__nv_bfloat16
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
write_smem_transpose
(
reinterpret_cast
<
const
uint4
&>
(
vec
),
reinterpret_cast
<
uint16_t
*>
(
smem
),
transpose_idx
,
smem_pitch
);
}
#endif
template
<
>
__device__
__inline__
void
write_smem_transpose
(
const
float2
&
vec
,
float
*
smem
,
int
transpose_idx
,
int
smem_pitch
)
{
smem
[
transpose_idx
]
=
vec
.
x
;
smem
[
smem_pitch
+
transpose_idx
]
=
vec
.
y
;
}
}
// namespace mmha
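For reference, the rotary helpers above are a direct implementation of rotary position embedding (RoPE) as a 2-D rotation applied to successive channel pairs. Writing t for the timestep (`t_step`), z for the even channel index passed as `zid`, d for `rot_embed_dim`, and b for `base`, `rotary_embedding_coefficient` computes the angle and `rotary_embedding_transform` applies the rotation:

\theta_z = \frac{t}{b^{\,z/d}}, \qquad
\begin{pmatrix} x' \\ y' \end{pmatrix} =
\begin{pmatrix} \cos\theta_z & -\sin\theta_z \\ \sin\theta_z & \cos\theta_z \end{pmatrix}
\begin{pmatrix} x \\ y \end{pmatrix}.

Despite its name, the local variable `inv_freq` already holds the angle t / b^{z/d} (inverse frequency times timestep); the wider vector overloads simply apply this same rotation to the consecutive (x, y) pairs of q and k.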
awq_cuda/attention/ft_attention.cpp
0 → 100644
// Adapted from NVIDIA/FasterTransformer and FlashAttention
#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAGuard.h>
#include "ft_attention.h"
#include "decoder_masked_multihead_attention.h"
#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...) \
if (TYPE == at::ScalarType::Half) { \
using scalar_t = at::Half; \
__VA_ARGS__(); \
} else if (TYPE == at::ScalarType::BFloat16) { \
using scalar_t = at::BFloat16; \
__VA_ARGS__(); \
} else if (TYPE == at::ScalarType::Float) { \
using scalar_t = float; \
__VA_ARGS__(); \
} else { \
AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
}
template<typename T>
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params, const cudaStream_t& stream);

template<typename T>
void cross_multihead_attention(const Masked_multihead_attention_params<T>& params, const cudaStream_t& stream);

template<typename T>
struct SATypeConverter {
    using Type = T;
};

template<>
struct SATypeConverter<at::Half> {
    using Type = uint16_t;
};

template<>
struct SATypeConverter<at::BFloat16> {
    using Type = __nv_bfloat16;
};

template<typename T>
void set_params(Masked_multihead_attention_params<T>& params,
                const size_t batch_size,
                const size_t nheads,
                const size_t nheads_kv,
                const size_t memory_max_seqlen,
                const size_t headdim,
                const int timestep,
                const int rotary_embedding_dim,
                const float rotary_base,
                const bool neox_rotary_style,
                const int qkv_batch_stride,
                T* q_ptr,
                T* k_ptr,
                T* v_ptr,
                T* k_cache_ptr,
                T* v_cache_ptr,
                int* length_per_sample,
                float* alibi_slopes_ptr,
                T* out_ptr) {
    // Reset the parameters
    memset(&params, 0, sizeof(params));
    params.q = q_ptr;
    params.k = k_ptr;
    params.v = v_ptr;
    params.q_bias = nullptr;
    params.k_bias = nullptr;
    params.v_bias = nullptr;
    params.k_cache = k_cache_ptr;
    params.v_cache = v_cache_ptr;
    params.linear_bias_slopes = alibi_slopes_ptr;
    params.out = out_ptr;
    params.cache_indir = nullptr;
    params.stride = qkv_batch_stride;
    params.batch_size = batch_size;
    params.beam_width = 1;
    params.memory_max_len = memory_max_seqlen;
    params.num_heads = nheads;
    params.num_kv_heads = nheads_kv;
    params.hidden_size_per_head = headdim;
    params.rotary_embedding_dim = rotary_embedding_dim;
    params.rotary_base = rotary_base;
    params.neox_rotary_style = neox_rotary_style;
    params.timestep = timestep;
    params.inv_sqrt_dh = 1.f / sqrt(float(headdim));

    params.total_padding_tokens = nullptr;
    params.masked_tokens = nullptr;
    params.prefix_prompt_lengths = nullptr;
    params.max_prefix_prompt_length = 0;
    params.relative_attention_bias = nullptr;
    params.relative_attention_bias_stride = 0;
    params.cross_attention_out = nullptr;
    params.max_decoder_seq_len = 0;
    params.is_return_cross_attentions = false;
    params.finished = nullptr;
    params.memory_length_per_sample = nullptr;
    params.length_per_sample = length_per_sample;
}

torch::Tensor single_query_attention(const torch::Tensor q,
                                     const torch::Tensor k,
                                     const torch::Tensor v,
                                     torch::Tensor k_cache,
                                     torch::Tensor v_cache,
                                     c10::optional<const torch::Tensor> length_per_sample_,
                                     c10::optional<const torch::Tensor> alibi_slopes_,
                                     const int timestep,
                                     const int rotary_embedding_dim,
                                     const float rotary_base,
                                     // neox_rotary_style = not interleaved
                                     const bool neox_rotary_style) {
    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
    int batch_size = v_cache.size(0);
    int nheads = q.size(1);
    int nheads_kv = v_cache.size(1);
    int memory_max_seqlen = v_cache.size(2);
    int headdim = v_cache.size(3);
    CHECK_SHAPE(q, batch_size, nheads, headdim);
    CHECK_SHAPE(k, batch_size, nheads_kv, headdim);
    CHECK_SHAPE(v, batch_size, nheads_kv, headdim);
    CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim);
    // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
    int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
    CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize);
    TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
    TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
    TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
    // TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
    CHECK_CONTIGUOUS(v_cache); CHECK_CONTIGUOUS(k_cache);

    if (length_per_sample_.has_value()) {
        auto length_per_sample = length_per_sample_.value();
        CHECK_DEVICE(length_per_sample);
        CHECK_SHAPE(length_per_sample, batch_size);
        CHECK_CONTIGUOUS(length_per_sample);
        TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
    }

    if (alibi_slopes_.has_value()) {
        auto alibi_slopes = alibi_slopes_.value();
        CHECK_DEVICE(alibi_slopes);
        CHECK_SHAPE(alibi_slopes, nheads);
        CHECK_CONTIGUOUS(alibi_slopes);
        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32);
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    torch::Tensor out = torch::empty_like(q);

    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
        using DataType = typename SATypeConverter<scalar_t>::Type;
        Masked_multihead_attention_params<DataType> params;
        set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, timestep,
                   rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
                   reinterpret_cast<DataType*>(q.data_ptr()),
                   reinterpret_cast<DataType*>(k.data_ptr()),
                   reinterpret_cast<DataType*>(v.data_ptr()),
                   reinterpret_cast<DataType*>(k_cache.data_ptr()),
                   reinterpret_cast<DataType*>(v_cache.data_ptr()),
                   length_per_sample_.has_value() ? length_per_sample_.value().data_ptr<int>() : nullptr,
                   alibi_slopes_.has_value() ? alibi_slopes_.value().data_ptr<float>() : nullptr,
                   reinterpret_cast<DataType*>(out.data_ptr()));
        auto stream = at::cuda::getCurrentCUDAStream();
        masked_multihead_attention(params, stream);
    });
    return out;
}
\ No newline at end of file
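The `CHECK_SHAPE(k_cache, ...)` check above expects the key cache in a packed [B, H, Dh/x, L, x] layout (x = 8 for fp16/bf16, x = 4 for fp32) rather than the usual [B, H, L, Dh]. Below is a minimal PyTorch sketch of a permutation that produces that shape; the helper name is hypothetical and whether this element ordering matches what the kernel actually reads back is an assumption, so treat it only as an illustration of the shape check.

import torch

def to_ft_kcache_layout(k_cache_bhld: torch.Tensor) -> torch.Tensor:
    """Repack a contiguous [B, H_kv, L, Dh] key cache into the
    [B, H_kv, Dh/x, L, x] shape checked by single_query_attention."""
    B, H, L, Dh = k_cache_bhld.shape
    x = 4 if k_cache_bhld.dtype == torch.float32 else 8
    assert Dh % x == 0
    # [B, H, L, Dh] -> [B, H, L, Dh/x, x] -> [B, H, Dh/x, L, x]
    return (k_cache_bhld.view(B, H, L, Dh // x, x)
                        .permute(0, 1, 3, 2, 4)
                        .contiguous())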
awq_cuda/attention/ft_attention.h
0 → 100644
#pragma once
#include <torch/extension.h>
torch::Tensor single_query_attention(const torch::Tensor q,
                                     const torch::Tensor k,
                                     const torch::Tensor v,
                                     torch::Tensor k_cache,
                                     torch::Tensor v_cache,
                                     c10::optional<const torch::Tensor> length_per_sample_,
                                     c10::optional<const torch::Tensor> alibi_slopes_,
                                     const int timestep,
                                     const int rotary_embedding_dim = 0,
                                     const float rotary_base = 10000.0f,
                                     const bool neox_rotary_style = true);
\ No newline at end of file
awq_cuda/pybind.cpp
#include <pybind11/pybind11.h>
#include <torch/extension.h>

#include "attention/ft_attention.h"
#include "layernorm/layernorm.h"
#include "quantization/gemm_cuda.h"
#include "quantization/gemv_cuda.h"
#include "position_embedding/pos_encoding.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
    m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
    m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
    m.def("rotary_embedding", &rotary_embedding, "Apply rotary embedding to query and key");
    m.def("single_query_attention", &single_query_attention, "Attention with a single query",
          py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
          py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"),
          py::arg("rotary_embedding_dim") = 0, py::arg("rotary_base") = 10000.0f,
          py::arg("neox_rotary_style") = true);
}
\ No newline at end of file
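A hypothetical Python-side call to the `single_query_attention` binding defined above, for one decoding step. The extension module name used below is an assumption (import whatever name setup.py registers), and the zero-initialized caches and lengths are placeholders; the two optional tensors may be passed as None.

import torch
import awq_inference_engine as ie  # assumed module name

B, H, H_kv, L, Dh = 1, 32, 32, 2048, 128
dev, dt = "cuda", torch.float16

q = torch.randn(B, H, Dh, device=dev, dtype=dt)
k = torch.randn(B, H_kv, Dh, device=dev, dtype=dt)
v = torch.randn(B, H_kv, Dh, device=dev, dtype=dt)
k_cache = torch.zeros(B, H_kv, Dh // 8, L, 8, device=dev, dtype=dt)  # [B, H, Dh/x, L, x], x=8 for fp16
v_cache = torch.zeros(B, H_kv, L, Dh, device=dev, dtype=dt)          # [B, H, L, Dh]
lengths = torch.full((B,), 10, device=dev, dtype=torch.int32)        # tokens already in the cache

out = ie.single_query_attention(q, k, v, k_cache, v_cache, lengths, None,
                                timestep=10,
                                rotary_embedding_dim=Dh,
                                rotary_base=10000.0,
                                neox_rotary_style=True)
# out has the same [B, H, Dh] shape and dtype as q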
awq_cuda/quantization/gemv_cuda.cu
0 → 100644
// Inspired by https://github.com/ankan-ban/llama_cu_awq
/*
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
}
*/
#include <cuda_fp16.h>
#include <stdio.h>
#include <torch/extension.h>
#include "gemv_cuda.h"
#define VECTORIZE_FACTOR 8
#define Q_VECTORIZE_FACTOR 8
#define PACK_FACTOR 8
#define WARP_SIZE 32
// Reduce sum within the warp using the tree reduction algorithm.
__device__ __forceinline__ float warp_reduce_sum(float sum) {
#pragma unroll
    for (int i = 4; i >= 0; i--) {
        sum += __shfl_down_sync(0xffffffff, sum, 1 << i);
    }
    /*
    // Equivalent to the following tree reduction implementation:
    sum += __shfl_down_sync(0xffffffff, sum, 16);
    sum += __shfl_down_sync(0xffffffff, sum, 8);
    sum += __shfl_down_sync(0xffffffff, sum, 4);
    sum += __shfl_down_sync(0xffffffff, sum, 2);
    sum += __shfl_down_sync(0xffffffff, sum, 1);
    */
    return sum;
}

__device__ __forceinline__ int make_divisible(int c, int divisor) {
    return (c + divisor - 1) / divisor;
}

/*
Computes GEMV (group_size = 64).

Args:
  inputs: vector of shape [batch_size, IC];
  weight: matrix of shape [OC, IC / 8];
  output: vector of shape [OC];
  zeros: matrix of shape [OC, IC / group_size / 8];
  scaling_factors: matrix of shape [OC, IC / group_size];

Notes:
  One cannot infer group_size from the shape of scaling factors.
  the second dimension is rounded up to a multiple of PACK_FACTOR.
*/
__global__ void gemv_kernel_g64(
    const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs,
    const int IC, const int OC) {
    const int group_size = 64;
    float psum = 0;
    const int batch_idx = blockIdx.z;
    const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
    const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
    half* outputs = _outputs + batch_idx * OC;
    // This is essentially zeros_w.
    const int num_groups_packed = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2;
    const int weight_w = IC / PACK_FACTOR;
    // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
    const int zeros_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2;
    // consistent with input shape
    const int sf_w = make_divisible(make_divisible(IC / group_size, PACK_FACTOR), 2) * 2 * PACK_FACTOR;
    // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w, sf_w);
    // tile size: 4 OC x 1024 IC per iter
    for (int packed_group_idx = 0; packed_group_idx < num_groups_packed / 2; packed_group_idx++) {
        // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
        uint64_t packed_zeros = *reinterpret_cast<const uint64_t*>(zeros + oc_idx * zeros_w + packed_group_idx * 2);
        uint32_t packed_weights[4];
        // use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
        *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
        // load scaling factors
        // g64: two threads -> 64 numbers -> 1 group; 1 warp = 16 groups.
        float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 16 + (threadIdx.x / 2)]);
        float current_zeros = (float)((packed_zeros >> (threadIdx.x / 2 * 4)) & 0xF);
        int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
        const float4* inputs_ptr = inputs + inputs_ptr_delta;
        // multiply 32 weights with 32 inputs
#pragma unroll
        for (int ic_0 = 0; ic_0 < 4; ic_0++) {
            // iterate over different uint32_t packed_weights in this loop
            uint32_t current_packed_weight = packed_weights[ic_0];
            half packed_inputs[PACK_FACTOR];
            // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
            if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
                *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
#pragma unroll
                for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++) {
                    // iterate over 8 numbers packed within each uint32_t number
                    float current_single_weight_fp = (float)(current_packed_weight & 0xF);
                    float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
                    // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
                    psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
                    current_packed_weight = current_packed_weight >> 4;
                }
            }
        }
    }
    psum = warp_reduce_sum(psum);
    if (threadIdx.x == 0) {
        outputs[oc_idx] = __float2half(psum);
    }
}

/*
Computes GEMV (group_size = 128).

Args:
  inputs: vector of shape [batch_size, IC];
  weight: matrix of shape [OC, IC / 8];
  output: vector of shape [OC];
  zeros: matrix of shape [OC, IC / group_size / 8];
  scaling_factors: matrix of shape [OC, IC / group_size];

Notes:
  One cannot infer group_size from the shape of scaling factors.
  the second dimension is rounded up to a multiple of PACK_FACTOR.
*/
__global__ void gemv_kernel_g128(
    const float4* _inputs, const uint32_t* weight, const uint32_t* zeros, const half* scaling_factors, half* _outputs,
    const int IC, const int OC) {
    const int group_size = 128;
    float psum = 0;
    const int batch_idx = blockIdx.z;
    const int oc_idx = blockIdx.y * blockDim.y + threadIdx.y;
    const float4* inputs = _inputs + batch_idx * IC / PACK_FACTOR;
    half* outputs = _outputs + batch_idx * OC;
    const int num_groups_packed = make_divisible(IC / group_size, PACK_FACTOR);
    const int weight_w = IC / PACK_FACTOR;
    // TODO (Haotian): zeros_w is incorrect, after fixing we got misaligned address
    const int zeros_w = make_divisible(IC / group_size, PACK_FACTOR);
    // consistent with input shape
    const int sf_w = make_divisible(IC / group_size, PACK_FACTOR) * PACK_FACTOR;
    // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0) printf("%d %d %d %d\n", IC, group_size, PACK_FACTOR, zeros_w);
    // tile size: 4 OC x 1024 IC per iter
    for (int packed_group_idx = 0; packed_group_idx < num_groups_packed; packed_group_idx++) {
        // 1024 numbers in one iteration across warp. Need 1024 / group_size zeros.
        uint32_t packed_zeros = *(zeros + oc_idx * zeros_w + packed_group_idx);
        uint32_t packed_weights[4];
        // use float4 to load weights, each thread load 32 int4 numbers (1 x float4)
        *((float4*)(packed_weights)) = *((float4*)(weight + oc_idx * weight_w + packed_group_idx * (WARP_SIZE * 4) + threadIdx.x * 4));
        // load scaling factors
        // g128: four threads -> 128 numbers -> 1 group; 1 warp = 8 groups.
        float scaling_factor = __half2float(scaling_factors[oc_idx * sf_w + packed_group_idx * 8 + (threadIdx.x / 4)]);
        float current_zeros = (float)((packed_zeros >> (threadIdx.x / 4 * 4)) & 0xF);
        int inputs_ptr_delta = packed_group_idx * WARP_SIZE * 4 + threadIdx.x * 4;
        const float4* inputs_ptr = inputs + inputs_ptr_delta;
        // multiply 32 weights with 32 inputs
#pragma unroll
        for (int ic_0 = 0; ic_0 < 4; ic_0++) {
            // iterate over different uint32_t packed_weights in this loop
            uint32_t current_packed_weight = packed_weights[ic_0];
            half packed_inputs[PACK_FACTOR];
            // each thread load 8 inputs, starting index is packed_group_idx * 128 * 8 (because each iter loads 128*8)
            if (inputs_ptr_delta + ic_0 < IC / PACK_FACTOR) {
                *((float4*)packed_inputs) = *(inputs_ptr + ic_0);
#pragma unroll
                for (int ic_1 = 0; ic_1 < PACK_FACTOR; ic_1++) {
                    // iterate over 8 numbers packed within each uint32_t number
                    float current_single_weight_fp = (float)(current_packed_weight & 0xF);
                    float dequantized_weight = scaling_factor * (current_single_weight_fp - current_zeros);
                    // if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0 && threadIdx.y == 0 && ic_0 == 0 && ic_1 == 0 && packed_group_idx == 0) printf("%f %f %f %f %X %X\n", dequantized_weight, current_single_weight_fp, scaling_factor, current_zeros, current_packed_weight, packed_zeros);
                    psum += dequantized_weight * __half2float(packed_inputs[ic_1]);
                    current_packed_weight = current_packed_weight >> 4;
                }
            }
        }
    }
    psum = warp_reduce_sum(psum);
    if (threadIdx.x == 0) {
        outputs[oc_idx] = __float2half(psum);
    }
}

/*
Computes GEMV (PyTorch interface).

Args:
  _in_feats: tensor of shape [B, IC];
  _kernel: int tensor of shape [OC, IC // 8];
  _zeros: int tensor of shape [OC, IC // G // 8];
  _scaling_factors: tensor of shape [OC, IC // G];
  blockDim_x: size of thread block, dimension x, where blockDim_x * workload_per_thread = IC;
  blockDim_y: size of thread block, dimension y, where blockDim_y * gridDim_y = OC;

Returns:
  out_feats: tensor of shape [B, OC];
*/
torch::Tensor gemv_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int group_size)
{
    int num_in_feats = _in_feats.size(0);
    int num_in_channels = _in_feats.size(1);
    // int kernel_volume = _out_in_map.size(1);
    auto in_feats = reinterpret_cast<float4*>(_in_feats.data_ptr<at::Half>());
    auto kernel = reinterpret_cast<uint32_t*>(_kernel.data_ptr<int>());
    auto zeros = reinterpret_cast<uint32_t*>(_zeros.data_ptr<int>());
    auto scaling_factors = reinterpret_cast<half*>(_scaling_factors.data_ptr<at::Half>());
    // auto out_in_map = _out_in_map.data_ptr<int>();
    auto options = torch::TensorOptions().dtype(_in_feats.dtype()).device(_in_feats.device());
    // kernel is [OC, IC]
    at::Tensor _out_feats = torch::empty({num_in_feats, _kernel.size(0)}, options);
    int num_out_feats = _out_feats.size(-2);
    int num_out_channels = _out_feats.size(-1);
    auto out_feats = reinterpret_cast<half*>(_out_feats.data_ptr<at::Half>());

    int blockDim_z = num_out_feats;
    dim3 num_blocks(1, num_out_channels / 4, num_out_feats);
    dim3 num_threads(32, 4);
    if (group_size == 64)
    {
        gemv_kernel_g64<<<num_blocks, num_threads>>>(
            // pointers
            in_feats, kernel, zeros, scaling_factors, out_feats,
            // constants
            num_in_channels, num_out_channels
        );
    }
    else if (group_size == 128)
    {
        gemv_kernel_g128<<<num_blocks, num_threads>>>(
            // pointers
            in_feats, kernel, zeros, scaling_factors, out_feats,
            // constants
            num_in_channels, num_out_channels
        );
    }
    return _out_feats;
}
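A hypothetical call to `gemv_forward_cuda` from Python, mainly to document the packed tensor shapes the interface expects. The module name is the same assumption as above, the random tensors are placeholders (real packed weights come from AWQ quantization), and per the docstring the second dimension of the zeros/scaling-factor tensors may additionally be rounded up to a multiple of PACK_FACTOR.

import torch
import awq_inference_engine as ie  # assumed module name

B, IC, OC, G = 1, 4096, 4096, 128  # G is the quantization group size (64 or 128)
x = torch.randn(B, IC, device="cuda", dtype=torch.float16)
qweight = torch.randint(0, 2**31 - 1, (OC, IC // 8), device="cuda", dtype=torch.int32)      # 8 int4 weights per int32
qzeros  = torch.randint(0, 2**31 - 1, (OC, IC // G // 8), device="cuda", dtype=torch.int32)  # 8 int4 zeros per int32
scales  = torch.rand(OC, IC // G, device="cuda", dtype=torch.float16)

# Inside the kernel every 4-bit weight q is dequantized as w = scale * (q - zero),
# multiplied with the fp16 activation in fp32, and warp-reduced into one output per OC.
y = ie.gemv_forward_cuda(x, qweight, scales, qzeros, G)  # -> [B, OC], fp16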
awq_cuda/quantization/gemv_cuda.h
0 → 100644
#pragma once
#include <torch/extension.h>
torch::Tensor gemv_forward_cuda(
    torch::Tensor _in_feats,
    torch::Tensor _kernel,
    torch::Tensor _scaling_factors,
    torch::Tensor _zeros,
    int group_size);
setup.py
...
@@ -47,11 +47,24 @@ requirements = [
    "torchvision"
]

def get_include_dirs():
    include_dirs = []

    conda_cuda_include_dir = os.path.join(get_python_lib(), "nvidia/cuda_runtime/include")
    if os.path.isdir(conda_cuda_include_dir):
        include_dirs.append(conda_cuda_include_dir)

    this_dir = os.path.dirname(os.path.abspath(__file__))
    include_dirs.append(this_dir)

    return include_dirs

def get_generator_flag():
    generator_flag = []
    torch_dir = torch.__path__[0]
    if os.path.exists(os.path.join(torch_dir, "include", "ATen", "CUDAGeneratorImpl.h")):
        generator_flag = ["-DOLD_GENERATOR_PATH"]

    return generator_flag

def check_dependencies():
    if CUDA_HOME is None:
...
@@ -77,6 +90,8 @@ def get_compute_capabilities():
    return capability_flags

check_dependencies()
include_dirs = get_include_dirs()
generator_flags = get_generator_flag()
arch_flags = get_compute_capabilities()

if os.name == "nt":
...
@@ -86,8 +101,21 @@ if os.name == "nt":
    }
else:
    extra_compile_args = {
        "cxx": ["-g", "-O3", "-fopenmp", "-lgomp", "-std=c++17", "-DENABLE_BF16"],
        "nvcc": [
            "-O3",
            "-std=c++17",
            "-DENABLE_BF16",
            "-U__CUDA_NO_HALF_OPERATORS__",
            "-U__CUDA_NO_HALF_CONVERSIONS__",
            "-U__CUDA_NO_BFLOAT16_OPERATORS__",
            "-U__CUDA_NO_BFLOAT16_CONVERSIONS__",
            "-U__CUDA_NO_BFLOAT162_OPERATORS__",
            "-U__CUDA_NO_BFLOAT162_CONVERSIONS__",
            "--expt-relaxed-constexpr",
            "--expt-extended-lambda",
            "--use_fast_math",
        ] + arch_flags + generator_flags
    }

extensions = [
...
@@ -97,7 +125,10 @@ extensions = [
        "awq_cuda/pybind.cpp",
        "awq_cuda/quantization/gemm_cuda_gen.cu",
        "awq_cuda/layernorm/layernorm.cu",
        "awq_cuda/position_embedding/pos_encoding_kernels.cu",
        "awq_cuda/quantization/gemv_cuda.cu",
        "awq_cuda/attention/ft_attention.cpp",
        "awq_cuda/attention/decoder_masked_multihead_attention.cu"
    ],
    extra_compile_args=extra_compile_args
)
]
...