OpenDAS / AutoAWQ / Commits

Unverified commit 2edb3f6f, authored Dec 28, 2023 by Casper, committed by GitHub on Dec 28, 2023.

AWQ: Separate the AWQ kernels to separate repository (#279)

Co-authored-by: Casper Hansen <casperbh96@gmail.com>
Parent: 3f10cf1d

Changes: 26 · Showing 20 changed files with 12 additions and 4541 deletions (+12 -4541)
README.md                                                            +1    -2
awq/modules/fused/attn.py                                            +2    -2
awq/modules/fused/mlp.py                                             +3    -3
awq/modules/fused/norm.py                                            +2    -2
awq/modules/linear.py                                                +4    -4
awq_cuda/attention/cuda_bf16_fallbacks.cuh                           +0  -257
awq_cuda/attention/cuda_bf16_wrapper.h                               +0   -23
awq_cuda/attention/decoder_masked_multihead_attention.cu             +0  -152
awq_cuda/attention/decoder_masked_multihead_attention.h              +0  -184
awq_cuda/attention/decoder_masked_multihead_attention_template.hpp   +0 -1608
awq_cuda/attention/decoder_masked_multihead_attention_utils.h        +0 -1786
awq_cuda/attention/ft_attention.cpp                                  +0  -182
awq_cuda/attention/ft_attention.h                                    +0   -15
awq_cuda/layernorm/layernorm.cu                                      +0  -113
awq_cuda/layernorm/layernorm.h                                       +0    -3
awq_cuda/layernorm/reduction.cuh                                     +0   -82
awq_cuda/position_embedding/pos_encoding.h                           +0    -9
awq_cuda/position_embedding/pos_encoding_kernels.cu                  +0   -88
awq_cuda/pybind_awq.cpp                                              +0   -15
awq_cuda/pybind_ft.cpp                                               +0   -11
README.md (view file @ 2edb3f6f)

@@ -32,6 +32,7 @@ AutoAWQ is an easy-to-use package for 4-bit quantized models. AutoAWQ speeds up
- Your GPU(s) must be of Compute Capability 7.5. Turing and later architectures are supported.
- Your CUDA version must be CUDA 11.8 or later.
- Requires installing [AutoAWQ kernels](https://github.com/casper-hansen/AutoAWQ_kernels).

### Install from PyPi

@@ -49,8 +50,6 @@ pip install https://github.com/casper-hansen/AutoAWQ/releases/download/v0.1.6/au
### Build from source

Build time can take 10-20 minutes. Download your model while you install AutoAWQ.

    git clone https://github.com/casper-hansen/AutoAWQ
    cd AutoAWQ
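Since the kernels now live in a separate repository, a quick way to confirm the extension is present before loading a quantized model is a plain import check. This is a minimal sketch, assuming only the `awq_ext` module name that appears in the diffs below; the exact install command is documented in the AutoAWQ_kernels repository linked above.

```python
# Minimal sketch: verify the separated AWQ kernel extension is importable.
try:
    import awq_ext  # provided by the AutoAWQ_kernels package (module name taken from the diffs below)
    print("AWQ CUDA kernels found")
except ImportError as err:
    raise ImportError(
        "AWQ kernels are missing. Install them from "
        "https://github.com/casper-hansen/AutoAWQ_kernels"
    ) from err
```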
awq/modules/fused/attn.py (view file @ 2edb3f6f)

@@ -8,7 +8,7 @@ from awq.utils.fused_utils import get_attention_shapes

 try:
-    import ft_inference_engine
+    import awq_ft_ext
     FT_INSTALLED = True
 except:
     FT_INSTALLED = False

@@ -214,7 +214,7 @@ class QuantAttentionFused(nn.Module):

             xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])

             alibi_slopes = self.alibi.slopes if self.alibi is not None else None
-            attention_weight = ft_inference_engine.single_query_attention(
+            attention_weight = awq_ft_ext.single_query_attention(
                 xq,  # query
                 xk,  # key
                 xv,  # value
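The hunk above keeps the optional-import pattern: the FasterTransformer-style extension is probed once at import time, and the `FT_INSTALLED` flag gates the fused path later in the module. A minimal sketch of that pattern in isolation, using only the names shown in the diff:

```python
# Sketch of the optional-extension pattern used in attn.py.
try:
    import awq_ft_ext
    FT_INSTALLED = True
except ImportError:
    FT_INSTALLED = False


def single_query_attention(xq, xk, xv, *args, **kwargs):
    """Call the fused kernel when available, otherwise fail with a clear message."""
    if not FT_INSTALLED:
        raise RuntimeError("awq_ft_ext is not installed; fused single-query attention is unavailable")
    return awq_ft_ext.single_query_attention(xq, xk, xv, *args, **kwargs)
```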
awq/modules/fused/mlp.py (view file @ 2edb3f6f)

 import torch.nn as nn
-import awq_inference_engine
+import awq_ext
 import torch.nn.functional as F
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

@@ -28,10 +28,10 @@ class QuantFusedMLP(nn.Module):

         self.down_proj = down_proj

         if isinstance(down_proj, WQLinear_GEMV):
-            self.linear = awq_inference_engine.gemv_forward_cuda
+            self.linear = awq_ext.gemv_forward_cuda
             self.group_size = down_proj.group_size
         else:
-            self.linear = awq_inference_engine.gemm_forward_cuda
+            self.linear = awq_ext.gemm_forward_cuda
             self.group_size = 8

         self.activation = activation
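The fused MLP picks its kernel from the type of the projection layer: a `WQLinear_GEMV` down projection routes through `gemv_forward_cuda` with the layer's own group size, everything else goes through `gemm_forward_cuda` with a fixed group size of 8. A small helper that mirrors that selection, purely for illustration; `awq_ext` and the `WQLinear_*` classes are the ones in the diff, the helper itself is not part of the repository:

```python
import awq_ext
from awq.modules.linear import WQLinear_GEMV


def select_mlp_kernel(down_proj):
    """Return (kernel_fn, group_size) as chosen in the QuantFusedMLP hunk above."""
    if isinstance(down_proj, WQLinear_GEMV):
        return awq_ext.gemv_forward_cuda, down_proj.group_size
    return awq_ext.gemm_forward_cuda, 8  # GEMM path uses a fixed group size of 8
```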
awq/modules/fused/norm.py (view file @ 2edb3f6f)

 import torch
 from torch import nn
-import awq_inference_engine
+import awq_ext


 class FasterTransformerRMSNorm(nn.Module):
     def __init__(self, weight, eps=1e-6):

@@ -10,5 +10,5 @@ class FasterTransformerRMSNorm(nn.Module):

     def forward(self, x):
         output = torch.empty_like(x)
-        awq_inference_engine.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
+        awq_ext.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
         return output
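For readers without the CUDA extension, here is a plain-PyTorch sketch of what `FasterTransformerRMSNorm` computes, assuming `layernorm_forward_cuda` implements the usual RMSNorm (normalize by the root mean square over the last dimension, then scale by `weight`), which is what the class name suggests rather than something the diff states explicitly:

```python
import torch


def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Compute the mean-square statistic in float32 for stability, then rescale by the learned weight.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    x_normed = x.float() * torch.rsqrt(variance + eps)
    return weight * x_normed.to(x.dtype)
```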
awq/modules/linear.py (view file @ 2edb3f6f)

 import math
 import torch
 import torch.nn as nn
-import awq_inference_engine  # with CUDA kernels
+import awq_ext  # with CUDA kernels


 def make_divisible(c, divisor):

@@ -102,7 +102,7 @@ class WQLinear_GEMM(nn.Module):

         if input_dtype != torch.float16:
             x = x.half()

-        out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
+        out = awq_ext.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)

         if input_dtype != torch.float16:
             out = out.to(dtype=input_dtype)

@@ -210,9 +210,9 @@ class WQLinear_GEMV(nn.Module):

             inputs = inputs.half()

         if inputs.shape[0] > 8:
-            out = awq_inference_engine.gemmv2_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size, self.split_k_iters)
+            out = awq_ext.gemmv2_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size, self.split_k_iters)
         else:
-            out = awq_inference_engine.gemv_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size)
+            out = awq_ext.gemv_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size)

         if input_dtype != torch.float16:
             out = out.to(dtype=input_dtype)
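The GEMV hunk dispatches on batch size: more than 8 input rows goes through `gemmv2_forward_cuda`, otherwise `gemv_forward_cuda`, with a float16 round-trip on either side of the kernel call. A hedged sketch of that wrapper logic in isolation; the argument list mirrors the diff above, and nothing else is assumed about the kernel signatures:

```python
import torch
import awq_ext  # separated AWQ kernels, as in the diff above


def wqlinear_gemv_forward(inputs, qweight, scales, qzeros, group_size, split_k_iters):
    """Sketch of the WQLinear_GEMV dispatch: cast to fp16, call the kernel, cast back."""
    input_dtype = inputs.dtype
    if input_dtype != torch.float16:
        inputs = inputs.half()
    if inputs.shape[0] > 8:
        out = awq_ext.gemmv2_forward_cuda(inputs, qweight, scales, qzeros, group_size, split_k_iters)
    else:
        out = awq_ext.gemv_forward_cuda(inputs, qweight, scales, qzeros, group_size)
    if input_dtype != torch.float16:
        out = out.to(dtype=input_dtype)
    return out
```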
awq_cuda/attention/cuda_bf16_fallbacks.cuh (deleted, 100644 → 0; view file @ 3f10cf1d)

// Downloaded from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_fallbacks.cuh
/*
 * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>

namespace fastertransformer {

#ifdef ENABLE_BF16
inline __device__ float2 bf1622float2(const __nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float2 f_val;
    f_val.x = __low2float(val);
    f_val.y = __high2float(val);
    return f_val;
#else
    return __bfloat1622float2(val);
#endif
}

inline __device__ int16_t bf1622int16(__nv_bfloat162 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float2 f_val;
    f_val.x = max(min(__low2float(val), 127.f), -128.f);
    f_val.y = max(min(__high2float(val), 127.f), -128.f);
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int8[0] = static_cast<int8_t>(static_cast<short>(f_val.x));
    int8[1] = static_cast<int8_t>(static_cast<short>(f_val.y));
    return int16;
#else
    val = __hmin2(val, make_bfloat162(127., 127.));
    val = __hmax2(val, make_bfloat162(-128., -128.));
    union {
        int8_t  int8[2];
        int16_t int16;
    };
    int8[0] = static_cast<int8_t>(static_cast<short>(val.x));
    int8[1] = static_cast<int8_t>(static_cast<short>(val.y));
    return int16;
#endif
}

inline __device__ __nv_bfloat162 float22bf162(const float2 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __floats2bfloat162_rn(val.x, val.y);
#else
    return __float22bfloat162_rn(val);
#endif
}

inline __device__ __nv_bfloat162 bf162bf162(const __nv_bfloat16 val)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    __nv_bfloat162 val2;
    val2.x = val;
    val2.y = val;
    return val2;
#else
    return __bfloat162bfloat162(val);
#endif
}

inline __device__ __nv_bfloat162 bf16hadd2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    fyl = __low2float(y);
    fyh = __high2float(y);
    return __floats2bfloat162_rn(fxl + fyl, fxh + fyh);
#else
    return __hadd2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hadd(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) + __bfloat162float(y));
#else
    return __hadd(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hsub2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    fyl = __low2float(y);
    fyh = __high2float(y);
    return __floats2bfloat162_rn(fxl - fyl, fxh - fyh);
#else
    return __hsub2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hsub(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) - __bfloat162float(y));
#else
    return __hsub(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hmul2(const __nv_bfloat162 x, const __nv_bfloat162 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    fyl = __low2float(y);
    fyh = __high2float(y);
    return __floats2bfloat162_rn(fxl * fyl, fxh * fyh);
#else
    return __hmul2(x, y);
#endif
}

inline __device__ __nv_bfloat16 bf16hmul(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y));
#else
    return __hmul(x, y);
#endif
}

inline __device__ __nv_bfloat162 bf16hfma2(const __nv_bfloat162 x, const __nv_bfloat162 y, const __nv_bfloat162 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh, fyl, fyh, fzl, fzh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    fyl = __low2float(y);
    fyh = __high2float(y);
    fzl = __low2float(z);
    fzh = __high2float(z);
    return __floats2bfloat162_rn(fxl * fyl + fzl, fxh * fyh + fzh);
#else
    return __hfma2(x, y, z);
#endif
}

inline __device__ __nv_bfloat16 bf16hfma(const __nv_bfloat16 x, const __nv_bfloat16 y, const __nv_bfloat16 z)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(x) * __bfloat162float(y) + __bfloat162float(z));
#else
    return __hfma(x, y, z);
#endif
}

inline __device__ __nv_bfloat162 bf16exp2(const __nv_bfloat162 x)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fxl, fxh;
    fxl = __low2float(x);
    fxh = __high2float(x);
    return __floats2bfloat162_rn(expf(fxl), expf(fxh));
#else
    return h2exp(x);
#endif
}

#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)
inline __device__ __nv_bfloat162 operator*(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hmul2(x, y); };
inline __device__ __nv_bfloat162 operator+(const __nv_bfloat162 x, const __nv_bfloat162 y) { return bf16hadd2(x, y); };

inline __device__ __nv_bfloat162 make_bfloat162(const __nv_bfloat16 x, const __nv_bfloat16 y)
{
    __nv_bfloat162 t;
    t.x = x;
    t.y = y;
    return t;
}
#endif

inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c));
#else
    return a + b + c;
#endif
}

inline __device__ __nv_bfloat16 bf16hadd(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c, __nv_bfloat16 d)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) + __bfloat162float(b) + __bfloat162float(c) + __bfloat162float(d));
#else
    return (__nv_bfloat16)((float)a + (float)b + (float)c + (float)d);
#endif
}

inline __device__ __nv_bfloat162 bf16hadd2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch;
    fal = __low2float(a);
    fah = __high2float(a);
    fbl = __low2float(b);
    fbh = __high2float(b);
    fcl = __low2float(c);
    fch = __high2float(c);
    return __floats2bfloat162_rn(fal + fbl + fcl, fah + fbh + fch);
#else
    return a + b + c;
#endif
}

inline __device__ __nv_bfloat16 bf16hmul(__nv_bfloat16 a, __nv_bfloat16 b, __nv_bfloat16 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    return __float2bfloat16(__bfloat162float(a) * __bfloat162float(b) * __bfloat162float(c));
#else
    return a * b * c;
#endif
}

inline __device__ __nv_bfloat162 bf16hmul2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch;
    fal = __low2float(a);
    fah = __high2float(a);
    fbl = __low2float(b);
    fbh = __high2float(b);
    fcl = __low2float(c);
    fch = __high2float(c);
    return __floats2bfloat162_rn(fal * fbl * fcl, fah * fbh * fch);
#else
    return a * b * c;
#endif
}

inline __device__ __nv_bfloat162 bf16hfma2(__nv_bfloat162 a, __nv_bfloat162 b, __nv_bfloat162 c, __nv_bfloat162 d)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 800
    float fal, fah, fbl, fbh, fcl, fch, fdl, fdh;
    fal = __low2float(a);
    fah = __high2float(a);
    fbl = __low2float(b);
    fbh = __high2float(b);
    fcl = __low2float(c);
    fch = __high2float(c);
    fdl = __low2float(d);
    fdh = __high2float(d);
    return __floats2bfloat162_rn(fal * fbl * fcl + fdl, fah * fbh * fch + fdh);
#else
    return a * b * c + d;
#endif
}

#endif  // ENABLE_BF16
}  // namespace fastertransformer
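The deleted header above emulates bfloat16 arithmetic on pre-Ampere GPUs by widening to float, computing, and rounding back to bfloat16. The same idea expressed in PyTorch terms, as a small illustrative sketch that is not part of the repository:

```python
import torch


def bf16_fma_fallback(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Emulate bf16hfma-style math: widen bfloat16 inputs to float32, multiply-add, round back."""
    return (x.float() * y.float() + z.float()).to(torch.bfloat16)
```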
awq_cuda/attention/cuda_bf16_wrapper.h (deleted, 100644 → 0; view file @ 3f10cf1d)

// Downloaded from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/utils/cuda_bf16_wrapper.h
/*
 * Copyright (c) 2019-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#ifdef ENABLE_BF16
#include <cuda_bf16.h>
#endif
awq_cuda/attention/decoder_masked_multihead_attention.cu (deleted, 100644 → 0; view file @ 3f10cf1d)

// Adapted from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
/*
 * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * see the full license header in cuda_bf16_fallbacks.cuh above.
 */

#include "decoder_masked_multihead_attention.h"
#include "decoder_masked_multihead_attention_utils.h"
#include "cuda_bf16_wrapper.h"
#include <assert.h>
#include <float.h>
#include <type_traits>

#include "decoder_masked_multihead_attention_template.hpp"

////////////////////////////////////////////////////////////////////////////////////////////////////

#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, DO_CROSS_ATTENTION, stream)   \
    size_t smem_sz = mmha::smem_size_in_bytes<T, DO_CROSS_ATTENTION>(params, THDS_PER_VALUE, THDS_PER_BLOCK);         \
    auto   kernel  = mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE,             \
                                                           THDS_PER_BLOCK, DO_CROSS_ATTENTION>;                       \
    cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_sz);                               \
    dim3 grid(params.num_heads, params.batch_size);                                                                   \
    kernel<<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

////////////////////////////////////////////////////////////////////////////////////////////////////

// !!! Specialize the launcher for Cross attention
template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    constexpr int  THREADS_PER_VALUE  = Dh_MAX * sizeof(T) / 16;
    constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
    int            tlength            = (DO_CROSS_ATTENTION) ? params.memory_max_len : params.timestep;
    // printf("tlength, CROSS_ATTENTION = %d, %d\n", tlength, DO_CROSS_ATTENTION);
    if (tlength < 32) {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, DO_CROSS_ATTENTION, stream);
    }
    else if (tlength < 2048) {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, DO_CROSS_ATTENTION, stream);
    }
    else {
        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, DO_CROSS_ATTENTION, stream);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

#undef MMHA_LAUNCH_KERNEL

template<typename T, typename KERNEL_PARAMS_TYPE>
void multihead_attention_(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream)
{
    switch (params.hidden_size_per_head) {
        case 32:
            mmha_launch_kernel<T, 32, 32, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 48:
            mmha_launch_kernel<T, 48, 64, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 64:
            mmha_launch_kernel<T, 64, 64, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 80:
            mmha_launch_kernel<T, 80, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 96:
            mmha_launch_kernel<T, 96, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 112:
            mmha_launch_kernel<T, 112, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 128:
            mmha_launch_kernel<T, 128, 128, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 160:
            mmha_launch_kernel<T, 160, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 192:
            mmha_launch_kernel<T, 192, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 224:
            mmha_launch_kernel<T, 224, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        case 256:
            mmha_launch_kernel<T, 256, 256, KERNEL_PARAMS_TYPE>(params, stream);
            break;
        default:
            assert(false);
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
    multihead_attention_<float, Masked_multihead_attention_params<float>>(params, stream);
}

void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
    multihead_attention_<uint16_t, Masked_multihead_attention_params<uint16_t>>(params, stream);
}

#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
                                const cudaStream_t& stream)
{
    multihead_attention_<__nv_bfloat16, Masked_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif

void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream)
{
    multihead_attention_<float, Cross_multihead_attention_params<float>>(params, stream);
}

void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream)
{
    multihead_attention_<uint16_t, Cross_multihead_attention_params<uint16_t>>(params, stream);
}

#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
                               const cudaStream_t& stream)
{
    multihead_attention_<__nv_bfloat16, Cross_multihead_attention_params<__nv_bfloat16>>(params, stream);
}
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
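The launcher above derives its block shape from the head size, `THREADS_PER_VALUE = Dh_MAX * sizeof(T) / 16`, and picks the threads-per-key / block-size pair from the cached sequence length. A quick sketch of that arithmetic in Python, with the example values purely illustrative:

```python
def mmha_launch_config(dh_max: int, dtype_bytes: int, tlength: int):
    """Mirror the launcher's choices: threads per value from the head size, the rest from tlength."""
    threads_per_value = dh_max * dtype_bytes // 16
    if tlength < 32:
        threads_per_key, threads_per_block = 4, 64
    elif tlength < 2048:
        threads_per_key, threads_per_block = 2, 128
    else:
        threads_per_key, threads_per_block = 1, 256
    return threads_per_key, threads_per_value, threads_per_block


# e.g. fp16 (2 bytes per element), Dh_MAX = 128, 1024 cached tokens -> (2, 16, 128)
print(mmha_launch_config(128, 2, 1024))
```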
awq_cuda/attention/decoder_masked_multihead_attention.h (deleted, 100644 → 0; view file @ 3f10cf1d)

// Downloaded from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention.h
/*
 * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
 * Licensed under the Apache License, Version 2.0 (the "License");
 * see the full license header in cuda_bf16_fallbacks.cuh above.
 */

#pragma once

#include "cuda_bf16_wrapper.h"
#include <cuda_fp16.h>
#include <cuda_runtime_api.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>

////////////////////////////////////////////////////////////////////////////////////////////////////

#define CHECK_CUDA(call)                                                                                  \
    do {                                                                                                  \
        cudaError_t status_ = call;                                                                       \
        if (status_ != cudaSuccess) {                                                                     \
            fprintf(stderr, "CUDA error (%s:%d): %s\n", __FILE__, __LINE__, cudaGetErrorString(status_)); \
            exit(1);                                                                                      \
        }                                                                                                 \
    } while (0)

////////////////////////////////////////////////////////////////////////////////////////////////////

// The structure of parameters for the masked multihead attention kernel.
//
// We use the following terminology to describe the different dimensions.
//
// B:  Batch size (number of sequences),
// L:  Sequence length,
// D:  Hidden dimension,
// H:  Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.

template<typename T>
struct Multihead_attention_params_base {

    // The output buffer. Dimensions B x D.
    T* out = nullptr;

    // The input Qs and the associated bias. Dimensions B x D and D, resp.
    const T *q = nullptr, *q_bias = nullptr;
    // The input Ks and the associated bias. Dimensions B x D and D, resp.
    const T *k = nullptr, *k_bias = nullptr;
    // The input Vs and the associated bias. Dimensions B x D and D, resp.
    const T *v = nullptr, *v_bias = nullptr;

    // The cache for the Ks. The size must be at least B x L x D.
    T* k_cache = nullptr;
    // The cache for the Vs. The size must be at least B x L x D.
    T* v_cache = nullptr;
    // The indirections to use for cache when beam sampling.
    const int* cache_indir = nullptr;

    // Stride to handle the case when KQV is a single buffer
    int stride = 0;

    // The batch size.
    int batch_size = 0;
    // The beam width
    int beam_width = 0;
    // The sequence length.
    int memory_max_len = 0;
    // The number of heads (H).
    int num_heads = 0;
    // The number of heads for KV cache.
    int num_kv_heads = 0;
    // The hidden dimension per head (Dh).
    int hidden_size_per_head = 0;
    // The per-head latent space reserved for rotary embeddings.
    int   rotary_embedding_dim = 0;
    bool  neox_rotary_style    = false;
    float rotary_base          = 0.0f;
    // The maximum length of input sentences.
    int max_input_length = 0;
    // The current timestep. TODO(bhsueh) Check that do we only this param in cross attention?
    int timestep = 0;
    // The current timestep of each sentences (support different timestep for different sentences)

    // The 1.f / sqrt(Dh). Computed on the host.
    float inv_sqrt_dh = 0.0f;

    // Used when we have some input context like gpt
    const int* total_padding_tokens = nullptr;

    const bool* masked_tokens            = nullptr;
    const int*  prefix_prompt_lengths    = nullptr;
    int         max_prefix_prompt_length = 0;

    const T* relative_attention_bias        = nullptr;
    int      relative_attention_bias_stride = 0;
    // The slope per head of linear position bias to attention score (H).
    const float* linear_bias_slopes = nullptr;

    const T*   ia3_key_weights   = nullptr;
    const T*   ia3_value_weights = nullptr;
    const int* ia3_tasks         = nullptr;

    const float* qkv_scale_out       = nullptr;
    const float* attention_out_scale = nullptr;
    int          int8_mode           = 0;
};

template<typename T, bool CROSS_ATTENTION>
struct Multihead_attention_params: public Multihead_attention_params_base<T> {
    // output cross attentions
    float* cross_attention_out        = nullptr;
    int    max_decoder_seq_len        = 0;
    bool   is_return_cross_attentions = false;

    // allows to exist attention eary
    bool* finished = nullptr;

    // required in case of cross attention
    // will need it here till if constexpr in c++17
    int* memory_length_per_sample = nullptr;

    // required in case of masked attention with different length
    const int* length_per_sample = nullptr;
};

template<typename T>
struct Multihead_attention_params<T, true>: public Multihead_attention_params_base<T> {
    // output cross attentions
    float* cross_attention_out        = nullptr;
    int    max_decoder_seq_len        = 0;
    bool   is_return_cross_attentions = false;

    // allows to exist attention eary
    bool* finished = nullptr;

    // required in case of cross attention
    int* memory_length_per_sample = nullptr;

    // required in case of masked attention with different length
    const int* length_per_sample = nullptr;
};

template<class T>
using Masked_multihead_attention_params = Multihead_attention_params<T, false>;

template<class T>
using Cross_multihead_attention_params = Multihead_attention_params<T, true>;

template<typename T>
struct outputCrossAttentionParam {
    // max decoder output length
    int  max_decoder_seq_len        = 0;
    T*   cross_attention_out        = nullptr;
    bool is_return_cross_attentions = false;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

void masked_multihead_attention(const Masked_multihead_attention_params<float>& params, const cudaStream_t& stream);
void masked_multihead_attention(const Masked_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void masked_multihead_attention(const Masked_multihead_attention_params<__nv_bfloat16>& params,
                                const cudaStream_t& stream);
#endif
void cross_multihead_attention(const Cross_multihead_attention_params<float>& params, const cudaStream_t& stream);
void cross_multihead_attention(const Cross_multihead_attention_params<uint16_t>& params, const cudaStream_t& stream);
#ifdef ENABLE_BF16
void cross_multihead_attention(const Cross_multihead_attention_params<__nv_bfloat16>& params,
                               const cudaStream_t& stream);
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////
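The comments in the deleted header define the dimensions B, L, D, H and Dh = D / H, and state that each of `k_cache` and `v_cache` must hold at least B x L x D elements. A small worked sketch of that arithmetic with illustrative values (the 7B-class shapes below are assumptions, used only to make the numbers concrete):

```python
# Illustrative only: dimensions follow the comments in decoder_masked_multihead_attention.h.
B, L, D, H = 4, 2048, 4096, 32            # batch, max sequence length, hidden dim, heads (assumed values)
Dh = D // H                               # hidden dimension per head -> 128
cache_elems = B * L * D                   # minimum element count of k_cache (and of v_cache)
cache_bytes_fp16 = cache_elems * 2        # two bytes per element in fp16
print(Dh, cache_elems, cache_bytes_fp16)  # 128, 33554432, 67108864 (64 MiB per cache)
```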
awq_cuda/attention/decoder_masked_multihead_attention_template.hpp
deleted
100644 → 0
View file @
3f10cf1d
// Downloaded from from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
/*
* Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "decoder_masked_multihead_attention.h"
#include "decoder_masked_multihead_attention_utils.h"
#include "cuda_bf16_wrapper.h"
#include "cuda_bf16_fallbacks.cuh"
#include <assert.h>
#include <float.h>
#include <type_traits>
// #define MMHA_USE_HMMA_FOR_REDUCTION
// Below are knobs to extend FP32 accumulation for higher FP16 accuracy
// Does not seem to affect the accuracy that much
#define MMHA_USE_FP32_ACUM_FOR_FMA
// Seems to slightly improve the accuracy
#define MMHA_USE_FP32_ACUM_FOR_OUT
#if 0 && defined(MMHA_USE_FP32_ACUM_FOR_OUT)
// Does not seem to improve the accuracy
//#define MMHA_USE_FP32_ACUM_FOR_LOGITS
#endif
namespace
mmha
{
////////////////////////////////////////////////////////////////////////////////////////////////////
//
// We use the following terminology to describe the different dimensions.
//
// B: Batch size (number of sequences),
// L: Sequence length,
// D: Hidden dimension,
// H: Number of heads,
// Dh: Hidden dimension per head - Dh = D / H.
//
// The different kernels assign a threadblock for B x H pair. The grid has size (1, B, H). We use
// 64, 128 and 256 threads per block.
//
// Each threadblock loads Dh values from Q and its associated bias. The kernels run a loop to
// compute Q * K^T where K is loaded from a cache buffer -- except for the current timestep. The
// cache buffer helps with memory accesses and contains keys with bias.
//
// The layout of the cache buffer for the keys is [B, H, Dh/x, L, x] where x == 8 for FP16 and
// x == 4 for FP32 where the fastest moving dimension (contiguous data) is the rightmost one. The
// values for x are chosen to create chunks of 16 bytes.
//
// The different kernels use 1, 2 or 4 threads per key (THREADS_PER_KEY). The size of the LDGs
// depends on the number of threads per key. Each thread sums Dh / THREADS_PER_KEY elements. At
// the end of each iteration of the Q * K^T loop, we perform a reduction between lanes using an
// HMMA instruction (Tensor Core). Each Q * K^T valuey is stored in shared memory in FP32.
//
// After that loop, a parallel softmax is computed across the different Q * K^T values stored in
// shared memory.
//
// The kernel ends with a loop over the values in V. We use THREADS_PER_VALUE to control how many
// timesteps are computed by loop iteration. As with the keys, the values are read from a cache
// except for the current timestep. The layout of the cache buffer for the values is much simpler
// as it is [B, H, L, Dh].
//
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
,
int
Dh
>
struct
Qk_vec_
{
};
template
<
>
struct
Qk_vec_
<
float
,
32
>
{
using
Type
=
float
;
};
template
<
>
struct
Qk_vec_
<
float
,
64
>
{
using
Type
=
float2
;
};
template
<
>
struct
Qk_vec_
<
float
,
128
>
{
using
Type
=
float4
;
};
template
<
>
struct
Qk_vec_
<
float
,
256
>
{
using
Type
=
float4
;
};
template
<
>
struct
Qk_vec_
<
uint16_t
,
32
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
Qk_vec_
<
uint16_t
,
64
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
Qk_vec_
<
uint16_t
,
128
>
{
using
Type
=
uint2
;
};
template
<
>
struct
Qk_vec_
<
uint16_t
,
256
>
{
using
Type
=
uint4
;
};
#ifdef ENABLE_BF16
template
<
>
struct
Qk_vec_
<
__nv_bfloat16
,
32
>
{
using
Type
=
__nv_bfloat162
;
};
template
<
>
struct
Qk_vec_
<
__nv_bfloat16
,
64
>
{
using
Type
=
__nv_bfloat162
;
};
template
<
>
struct
Qk_vec_
<
__nv_bfloat16
,
128
>
{
using
Type
=
bf16_4_t
;
};
template
<
>
struct
Qk_vec_
<
__nv_bfloat16
,
256
>
{
using
Type
=
bf16_8_t
;
};
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
,
int
THREADS_PER_KEY
>
struct
K_vec_
{
};
template
<
>
struct
K_vec_
<
float
,
4
>
{
using
Type
=
float
;
};
template
<
>
struct
K_vec_
<
float
,
2
>
{
using
Type
=
float2
;
};
template
<
>
struct
K_vec_
<
float
,
1
>
{
using
Type
=
float4
;
};
template
<
>
struct
K_vec_
<
uint16_t
,
4
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
K_vec_
<
uint16_t
,
2
>
{
using
Type
=
uint2
;
};
template
<
>
struct
K_vec_
<
uint16_t
,
1
>
{
using
Type
=
uint4
;
};
#ifdef ENABLE_BF16
template
<
>
struct
K_vec_
<
__nv_bfloat16
,
4
>
{
using
Type
=
__nv_bfloat162
;
};
template
<
>
struct
K_vec_
<
__nv_bfloat16
,
2
>
{
using
Type
=
bf16_4_t
;
};
template
<
>
struct
K_vec_
<
__nv_bfloat16
,
1
>
{
using
Type
=
bf16_8_t
;
};
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
,
int
V_VEC_SIZE
>
struct
V_vec_
{
};
template
<
>
struct
V_vec_
<
float
,
1
>
{
using
Type
=
float
;
};
template
<
>
struct
V_vec_
<
float
,
2
>
{
using
Type
=
float2
;
};
template
<
>
struct
V_vec_
<
float
,
4
>
{
using
Type
=
float4
;
};
template
<
>
struct
V_vec_
<
uint16_t
,
2
>
{
using
Type
=
uint32_t
;
};
template
<
>
struct
V_vec_
<
uint16_t
,
4
>
{
using
Type
=
uint2
;
};
template
<
>
struct
V_vec_
<
uint16_t
,
8
>
{
using
Type
=
uint4
;
};
#ifdef ENABLE_BF16
template
<
>
struct
V_vec_
<
__nv_bfloat16
,
2
>
{
using
Type
=
__nv_bfloat162
;
};
template
<
>
struct
V_vec_
<
__nv_bfloat16
,
4
>
{
using
Type
=
bf16_4_t
;
};
template
<
>
struct
V_vec_
<
__nv_bfloat16
,
8
>
{
using
Type
=
bf16_8_t
;
};
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
template
<
typename
T
>
struct
Qk_vec_acum_fp32_
{
};
template
<
>
struct
Qk_vec_acum_fp32_
<
float
>
{
using
Type
=
float
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
float2
>
{
using
Type
=
float2
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
float4
>
{
using
Type
=
float4
;
};
// template<> struct Qk_vec_acum_fp32_<uint16_t> { using Type = float; };
template
<
>
struct
Qk_vec_acum_fp32_
<
uint32_t
>
{
using
Type
=
float2
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
uint2
>
{
using
Type
=
Float4_
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
uint4
>
{
using
Type
=
Float8_
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
__nv_bfloat16
>
{
using
Type
=
float
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
__nv_bfloat162
>
{
using
Type
=
float2
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
bf16_4_t
>
{
using
Type
=
Float4_
;
};
template
<
>
struct
Qk_vec_acum_fp32_
<
bf16_8_t
>
{
using
Type
=
Float8_
;
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
>
struct
K_vec_acum_fp32_
{
};
template
<
>
struct
K_vec_acum_fp32_
<
float
>
{
using
Type
=
float
;
};
template
<
>
struct
K_vec_acum_fp32_
<
float2
>
{
using
Type
=
float2
;
};
template
<
>
struct
K_vec_acum_fp32_
<
float4
>
{
using
Type
=
float4
;
};
template
<
>
struct
K_vec_acum_fp32_
<
uint32_t
>
{
using
Type
=
float2
;
};
template
<
>
struct
K_vec_acum_fp32_
<
uint2
>
{
using
Type
=
Float4_
;
};
template
<
>
struct
K_vec_acum_fp32_
<
uint4
>
{
using
Type
=
Float8_
;
};
template
<
>
struct
K_vec_acum_fp32_
<
__nv_bfloat16
>
{
using
Type
=
float
;
};
template
<
>
struct
K_vec_acum_fp32_
<
__nv_bfloat162
>
{
using
Type
=
float2
;
};
template
<
>
struct
K_vec_acum_fp32_
<
bf16_4_t
>
{
using
Type
=
Float4_
;
};
template
<
>
struct
K_vec_acum_fp32_
<
bf16_8_t
>
{
using
Type
=
Float8_
;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
template
<
typename
T
>
struct
V_vec_acum_fp32_
{
};
template
<
>
struct
V_vec_acum_fp32_
<
float
>
{
using
Type
=
float
;
};
template
<
>
struct
V_vec_acum_fp32_
<
float2
>
{
using
Type
=
float2
;
};
template
<
>
struct
V_vec_acum_fp32_
<
float4
>
{
using
Type
=
float4
;
};
template
<
>
struct
V_vec_acum_fp32_
<
uint32_t
>
{
using
Type
=
float2
;
};
template
<
>
struct
V_vec_acum_fp32_
<
uint2
>
{
using
Type
=
Float4_
;
};
template
<
>
struct
V_vec_acum_fp32_
<
uint4
>
{
using
Type
=
Float8_
;
};
#ifdef ENABLE_BF16
template
<
>
struct
V_vec_acum_fp32_
<
__nv_bfloat162
>
{
using
Type
=
float2
;
};
template
<
>
struct
V_vec_acum_fp32_
<
bf16_4_t
>
{
using
Type
=
Float4_
;
};
template
<
>
struct
V_vec_acum_fp32_
<
bf16_8_t
>
{
using
Type
=
Float8_
;
};
#endif // ENABLE_BF16
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
THREADS_PER_KEY
,
typename
K_vec
,
int
N
>
inline
__device__
float
qk_dot_
(
const
K_vec
(
&
q
)[
N
],
const
K_vec
(
&
k
)[
N
])
{
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using
K_vec_acum
=
typename
K_vec_acum_fp32_
<
K_vec
>::
Type
;
#else
using
K_vec_acum
=
K_vec
;
#endif
// Compute the parallel products for Q*K^T (treat vector lanes separately).
K_vec_acum
qk_vec
=
mul
<
K_vec_acum
,
K_vec
,
K_vec
>
(
q
[
0
],
k
[
0
]);
#pragma unroll
for
(
int
ii
=
1
;
ii
<
N
;
++
ii
)
{
qk_vec
=
fma
(
q
[
ii
],
k
[
ii
],
qk_vec
);
}
// Finalize the reduction across lanes.
float
qk
=
sum
(
qk_vec
);
#pragma unroll
for
(
int
mask
=
THREADS_PER_KEY
/
2
;
mask
>=
1
;
mask
/=
2
)
{
qk
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
qk
,
mask
);
}
return
qk
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
,
int
THREADS_PER_KEY
>
struct
Qk_dot
{
template
<
typename
K_vec
,
int
N
>
static
inline
__device__
float
dot
(
const
K_vec
(
&
q
)[
N
],
const
K_vec
(
&
k
)[
N
])
{
return
qk_dot_
<
THREADS_PER_KEY
>
(
q
,
k
);
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
hmma_fp32
(
const
uint2
&
a
,
uint32_t
b
)
{
float4
c
;
float
zero
=
0.
f
;
asm
volatile
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32
\n
"
" {%0, %1, %2, %3},
\n
"
" {%4, %5},
\n
"
" {%6},
\n
"
" {%7, %7, %7, %7};
\n
"
:
"=f"
(
c
.
x
),
"=f"
(
c
.
y
),
"=f"
(
c
.
z
),
"=f"
(
c
.
w
)
:
"r"
(
a
.
x
)
"r"
(
a
.
y
),
"r"
(
b
),
"f"
(
zero
));
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
N
>
inline
__device__
float
qk_hmma_dot_
(
const
uint32_t
(
&
q
)[
N
],
const
uint32_t
(
&
k
)[
N
])
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
using
K_vec_acum
=
typename
K_vec_acum_fp32_
<
uint32_t
>::
Type
;
#else
using
K_vec_acum
=
uint32_t
;
#endif
K_vec_acum
qk_vec
=
mul
<
K_vec_acum
,
uint32_t
,
uint32_t
>
(
q
[
0
],
k
[
0
]);
#pragma unroll
for
(
int
ii
=
1
;
ii
<
N
;
++
ii
)
{
qk_vec
=
fma
(
q
[
ii
],
k
[
ii
],
qk_vec
);
}
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
uint32_t
qk_vec_
=
float2_to_half2
(
qk_vec
);
return
hmma_fp32
(
make_uint2
(
qk_vec_
,
0u
),
0x3c003c00u
).
x
;
#else
return
hmma_fp32
(
make_uint2
(
qk_vec
,
0u
),
0x3c003c00u
).
x
;
#endif
#else
return
0.
f
;
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
struct
Qk_dot
<
uint16_t
,
4
>
{
template
<
int
N
>
static
inline
__device__
float
dot
(
const
uint32_t
(
&
q
)[
N
],
const
uint32_t
(
&
k
)[
N
])
{
#if __CUDA_ARCH__ >= 750 && defined(MMHA_USE_HMMA_FOR_REDUCTION)
return
qk_hmma_dot_
(
q
,
k
);
#else
return
qk_dot_
<
4
>
(
q
,
k
);
#endif // defined MMHA_USE_HMMA_FOR_REDUCTION
}
};
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
WARPS_PER_BLOCK
,
int
WARP_SIZE
=
32
>
inline
__device__
float
block_sum
(
float
*
red_smem
,
float
sum
)
{
// Decompose the thread index into warp / lane.
int
warp
=
threadIdx
.
x
/
WARP_SIZE
;
int
lane
=
threadIdx
.
x
%
WARP_SIZE
;
// Compute the sum per warp.
#pragma unroll
for
(
int
mask
=
WARP_SIZE
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
sum
,
mask
);
}
// Warp leaders store the data to shared memory.
if
(
lane
==
0
)
{
red_smem
[
warp
]
=
sum
;
}
// Make sure the data is in shared memory.
__syncthreads
();
// The warps compute the final sums.
if
(
lane
<
WARPS_PER_BLOCK
)
{
sum
=
red_smem
[
lane
];
}
// Parallel reduction inside the warp.
#pragma unroll
for
(
int
mask
=
WARPS_PER_BLOCK
/
2
;
mask
>=
1
;
mask
/=
2
)
{
sum
+=
__shfl_xor_sync
(
uint32_t
(
-
1
),
sum
,
mask
);
}
// Broadcast to other threads.
return
__shfl_sync
(
uint32_t
(
-
1
),
sum
,
0
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
float
&
dst
,
float
src
)
{
dst
=
src
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint16_t
&
dst
,
float
src
)
{
dst
=
float_to_half
(
src
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint32_t
&
dst
,
float2
src
)
{
dst
=
float2_to_half2
(
src
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
void
convert_from_float
(
__nv_bfloat16
&
dst
,
float
src
)
{
dst
=
__float2bfloat16
(
src
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
__nv_bfloat162
&
dst
,
float2
src
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst
=
__float22bfloat162_rn
(
src
);
#else
dst
=
__floats2bfloat162_rn
(
src
.
x
,
src
.
y
);
#endif
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint2
&
dst
,
Float4_
src
)
{
dst
.
x
=
float2_to_half2
(
src
.
x
);
dst
.
y
=
float2_to_half2
(
src
.
y
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint2
&
dst
,
float4
src
)
{
convert_from_float
(
dst
,
Float4_
{
make_float2
(
src
.
x
,
src
.
y
),
make_float2
(
src
.
z
,
src
.
w
)});
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
uint4
&
dst
,
Float8_
src
)
{
dst
.
x
=
float2_to_half2
(
src
.
x
);
dst
.
y
=
float2_to_half2
(
src
.
y
);
dst
.
z
=
float2_to_half2
(
src
.
z
);
dst
.
w
=
float2_to_half2
(
src
.
w
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
void
convert_from_float
(
bf16_4_t
&
dst
,
Float4_
src
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst
.
x
=
__float22bfloat162_rn
(
src
.
x
);
dst
.
y
=
__float22bfloat162_rn
(
src
.
y
);
#else
dst
.
x
=
__floats2bfloat162_rn
(
src
.
x
.
x
,
src
.
x
.
y
);
dst
.
y
=
__floats2bfloat162_rn
(
src
.
y
.
x
,
src
.
y
.
y
);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
bf16_4_t
&
dst
,
float4
src
)
{
convert_from_float
(
dst
,
Float4_
{
make_float2
(
src
.
x
,
src
.
y
),
make_float2
(
src
.
z
,
src
.
w
)});
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
bf16_8_t
&
dst
,
Float8_
src
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
dst
.
x
=
__float22bfloat162_rn
(
src
.
x
);
dst
.
y
=
__float22bfloat162_rn
(
src
.
y
);
dst
.
z
=
__float22bfloat162_rn
(
src
.
z
);
dst
.
w
=
__float22bfloat162_rn
(
src
.
w
);
#else
dst
.
x
=
__floats2bfloat162_rn
(
src
.
x
.
x
,
src
.
x
.
y
);
dst
.
y
=
__floats2bfloat162_rn
(
src
.
y
.
x
,
src
.
y
.
y
);
dst
.
z
=
__floats2bfloat162_rn
(
src
.
z
.
x
,
src
.
z
.
y
);
dst
.
w
=
__floats2bfloat162_rn
(
src
.
w
.
x
,
src
.
w
.
y
);
#endif
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
float2
&
dst
,
float2
src
)
{
dst
=
src
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
void
convert_from_float
(
float4
&
dst
,
float4
src
)
{
dst
=
src
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
convert_to_float
(
float4
u
)
{
return
u
.
x
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
convert_to_float
(
uint4
u
)
{
float2
tmp
=
half2_to_float2
(
u
.
x
);
return
tmp
.
x
;
}
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
cast_to_float
(
float
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
cast_to_float
(
float2
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
cast_to_float
(
float4
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
cast_to_float
(
Float4_
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
cast_to_float
(
Float8_
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
cast_to_float
(
uint32_t
u
)
{
return
half2_to_float2
(
u
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
cast_to_float
(
uint2
u
)
{
Float4_
tmp
;
tmp
.
x
=
half2_to_float2
(
u
.
x
);
tmp
.
y
=
half2_to_float2
(
u
.
y
);
return
tmp
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
cast_to_float
(
uint4
u
)
{
Float8_
tmp
;
tmp
.
x
=
half2_to_float2
(
u
.
x
);
tmp
.
y
=
half2_to_float2
(
u
.
y
);
tmp
.
z
=
half2_to_float2
(
u
.
z
);
tmp
.
w
=
half2_to_float2
(
u
.
w
);
return
tmp
;
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
float_from_int8
(
int8_t
u
)
{
return
u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
float_from_int8
(
int16_t
u
)
{
union
{
int16_t
int16
;
int8_t
int8
[
2
];
};
int16
=
u
;
return
make_float2
(
int8
[
0
],
int8
[
1
]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
float_from_int8
(
int32_t
u
)
{
union
{
int32_t
int32
;
int8_t
int8
[
4
];
};
int32
=
u
;
return
make_float4
(
int8
[
0
],
int8
[
1
],
int8
[
2
],
int8
[
3
]);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// clang-format off
inline
__device__
Float8_
float_from_int8
(
int64_t
u
)
{
union
{
int64_t
int64
;
int16_t
int16
[
4
];
};
int64
=
u
;
return
Float8_
{
float_from_int8
(
int16
[
0
]),
float_from_int8
(
int16
[
1
]),
float_from_int8
(
int16
[
2
]),
float_from_int8
(
int16
[
3
])};
}
// clang-format on
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
int8_t
cast_to_int8
(
float
val
)
{
union
{
int8_t
int8
[
2
];
int16_t
int16
;
};
asm
volatile
(
"cvt.rni.sat.s8.f32 %0, %1;"
:
"=h"
(
int16
)
:
"f"
(
val
));
return
int8
[
0
];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
int32_t
cast_to_int8
(
float4
val
)
{
union
{
int8_t
int8
[
4
];
int32_t
int32
;
};
int8
[
0
]
=
cast_to_int8
(
val
.
x
);
int8
[
1
]
=
cast_to_int8
(
val
.
y
);
int8
[
2
]
=
cast_to_int8
(
val
.
z
);
int8
[
3
]
=
cast_to_int8
(
val
.
w
);
return
int32
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
int64_t
cast_to_int8
(
Float8_
val
)
{
union
{
int8_t
int8
[
8
];
int64_t
int64
;
};
int8
[
0
]
=
cast_to_int8
(
val
.
x
.
x
);
int8
[
1
]
=
cast_to_int8
(
val
.
x
.
y
);
int8
[
2
]
=
cast_to_int8
(
val
.
y
.
x
);
int8
[
3
]
=
cast_to_int8
(
val
.
y
.
y
);
int8
[
4
]
=
cast_to_int8
(
val
.
z
.
x
);
int8
[
5
]
=
cast_to_int8
(
val
.
z
.
y
);
int8
[
6
]
=
cast_to_int8
(
val
.
w
.
x
);
int8
[
7
]
=
cast_to_int8
(
val
.
w
.
y
);
return
int64
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
>
inline
__device__
__host__
T
div_up
(
T
m
,
T
n
)
{
return
(
m
+
n
-
1
)
/
n
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
T
,
bool
DO_CROSS_ATTENTION
>
inline
size_t
smem_size_in_bytes
(
const
Multihead_attention_params
<
T
,
DO_CROSS_ATTENTION
>&
params
,
int
threads_per_value
,
int
threads_per_block
)
{
// The amount of shared memory needed to store the Q*K^T values in float.
const
int
max_timesteps
=
min
(
params
.
timestep
,
params
.
memory_max_len
);
size_t
qk_sz
=
(
DO_CROSS_ATTENTION
)
?
div_up
(
params
.
memory_max_len
+
1
,
4
)
*
16
:
div_up
(
max_timesteps
+
1
,
4
)
*
16
;
// The extra memory needed if we are not using floats for the final logits.
size_t
logits_sz
=
0
;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
if
(
sizeof
(
T
)
!=
4
)
{
// TDOD
logits_sz
=
(
DO_CROSS_ATTENTION
)
?
div_up
(
params
.
memory_max_len
+
1
,
4
)
*
4
*
sizeof
(
T
)
:
div_up
(
max_timesteps
+
1
,
4
)
*
4
*
sizeof
(
T
);
}
#endif
// The total size needed during softmax.
size_t
softmax_sz
=
qk_sz
+
logits_sz
;
// The number of partial rows to reduce in the final reduction.
int
rows_per_red
=
threads_per_block
/
threads_per_value
;
// The amount of storage needed to finalize the outputs.
size_t
red_sz
=
rows_per_red
*
params
.
hidden_size_per_head
*
sizeof
(
T
)
/
2
;
size_t
transpose_rotary_size
=
0
;
if
(
params
.
rotary_embedding_dim
>
0
&&
params
.
neox_rotary_style
)
{
transpose_rotary_size
=
2
*
params
.
rotary_embedding_dim
*
sizeof
(
T
);
}
// The max.
return
max
(
max
(
softmax_sz
,
red_sz
),
transpose_rotary_size
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
constexpr
uint32_t
shfl_mask
(
int
threads
)
{
return
threads
==
32
?
uint32_t
(
-
1
)
:
(
1u
<<
threads
)
-
1u
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
// The type of the inputs. Supported types: float and half.
typename
T
,
// The hidden dimension per head.
int
Dh
,
int
Dh_MAX
,
// The number of threads per key.
int
THREADS_PER_KEY
,
// The number of threads per value.
int
THREADS_PER_VALUE
,
// The number of threads in a threadblock.
int
THREADS_PER_BLOCK
,
bool
DO_CROSS_ATTENTION
>
__global__
void
masked_multihead_attention_kernel
(
Multihead_attention_params
<
T
,
DO_CROSS_ATTENTION
>
params
)
{
// Make sure the hidden dimension per head is a multiple of the number of threads per key.
static_assert
(
Dh_MAX
%
THREADS_PER_KEY
==
0
,
""
);
// Make sure the hidden dimension per head is a multiple of the number of threads per value.
static_assert
(
Dh_MAX
%
THREADS_PER_VALUE
==
0
,
""
);
// The size of a warp.
constexpr
int
WARP_SIZE
=
32
;
// The number of warps in a threadblock.
constexpr
int
WARPS_PER_BLOCK
=
THREADS_PER_BLOCK
/
WARP_SIZE
;
// Use smem_size_in_bytes (above) to determine the amount of shared memory.
extern
__shared__
char
smem_
[];
// The shared memory for the Q*K^T values and partial logits in softmax.
float
*
qk_smem
=
reinterpret_cast
<
float
*>
(
smem_
);
// The shared memory for the logits. For FP32, that's the same buffer as qk_smem.
char
*
logits_smem_
=
smem_
;
#ifndef MMHA_USE_FP32_ACUM_FOR_LOGITS
if
(
sizeof
(
T
)
!=
4
)
{
// TODO - change to tlength
const
int
max_timesteps
=
min
(
params
.
timestep
,
params
.
memory_max_len
);
logits_smem_
+=
(
DO_CROSS_ATTENTION
)
?
div_up
(
params
.
memory_max_len
+
1
,
4
)
*
16
:
div_up
(
max_timesteps
+
1
,
4
)
*
16
;
}
T
*
logits_smem
=
reinterpret_cast
<
T
*>
(
logits_smem_
);
#else
float
*
logits_smem
=
reinterpret_cast
<
float
*>
(
logits_smem_
);
#endif
// The shared memory to do the final reduction for the output values. Reuse qk_smem.
T
*
out_smem
=
reinterpret_cast
<
T
*>
(
smem_
);
// The shared memory buffers for the block-wide reductions. One for max, one for sum.
__shared__
float
red_smem
[
WARPS_PER_BLOCK
*
2
];
// A vector of Q or K elements for the current timestep.
using
Qk_vec
=
typename
Qk_vec_
<
T
,
Dh_MAX
>::
Type
;
// Use alignment for safely casting the shared buffers as Qk_vec.
// Shared memory to store Q inputs.
__shared__
__align__
(
sizeof
(
Qk_vec
))
T
q_smem
[
Dh_MAX
];
// This is one of the reasons we should have a separate kernel for cross attention
__shared__
__align__
(
sizeof
(
Qk_vec
))
T
bias_smem
[
DO_CROSS_ATTENTION
?
Dh_MAX
:
1
];
// A vector of Q or K elements for the current timestep.
using
Qk_vec
=
typename
Qk_vec_
<
T
,
Dh_MAX
>::
Type
;
// The number of elements per vector.
constexpr
int
QK_VEC_SIZE
=
sizeof
(
Qk_vec
)
/
sizeof
(
T
);
// Make sure the hidden size per head is a multiple of the vector size.
static_assert
(
Dh_MAX
%
QK_VEC_SIZE
==
0
,
""
);
// We will use block wide reduction if needed
// static_assert(Dh_MAX / QK_VEC_SIZE <= WARP_SIZE, "");
// The number of vectors per warp.
constexpr
int
QK_VECS_PER_WARP
=
Dh_MAX
/
QK_VEC_SIZE
;
// The layout of the cache is [B, H, Dh/x, L, x] with x == 4/8 for FP32/FP16. Since each thread
// owns x elements, we have to decompose the linear index into chunks of x values and the posi-
// tion of the thread in that chunk.
// The number of elements in a chunk of 16B (that's the x in the above formula).
constexpr
int
QK_ELTS_IN_16B
=
16
/
sizeof
(
T
);
// The number of K vectors in 16B.
constexpr
int
QK_VECS_IN_16B
=
16
/
sizeof
(
Qk_vec
);
// The batch/beam idx
const
int
bi
=
blockIdx
.
y
;
if
(
params
.
finished
!=
nullptr
&&
params
.
finished
[
bi
]
==
true
)
{
return
;
}
// The beam idx
const
int
beami
=
bi
%
params
.
beam_width
;
// The "beam-aware" batch idx
const
int
bbi
=
bi
/
params
.
beam_width
;
// The head.
const
int
num_kv_heads
=
params
.
num_kv_heads
;
const
int
kv_rep
=
(
params
.
num_heads
/
num_kv_heads
);
const
int
hi
=
blockIdx
.
x
;
const
int
hi_kv
=
hi
/
kv_rep
;
// Combine the batch and the head indices.
const
int
bhi
=
bi
*
params
.
num_heads
+
hi
;
const
int
bhi_kv
=
bi
*
(
params
.
num_heads
/
kv_rep
)
+
hi_kv
;
// Combine the "beam-aware" batch idx and the head indices.
const
int
bbhi
=
bbi
*
params
.
beam_width
*
params
.
num_heads
+
hi
;
const
int
bbhi_kv
=
bbi
*
params
.
beam_width
*
(
params
.
num_heads
/
kv_rep
)
+
hi_kv
;
// The thread in the block.
const
int
tidx
=
threadIdx
.
x
;
const
bool
handle_kv
=
!
DO_CROSS_ATTENTION
||
(
DO_CROSS_ATTENTION
&&
params
.
timestep
==
0
);
// Every kv_rep threads have the same kv_cache values. So only the first one writes back.
const
int
write_kv_cache
=
handle_kv
&&
(
hi
%
kv_rep
==
0
);
// While doing the product Q*K^T for the different keys we track the max.
float
qk_max
=
-
FLT_MAX
;
float
qk
=
0.0
F
;
    // int qkv_base_offset = (params.stride == 0) ? bhi * Dh : bi * params.stride + hi * Dh;
    const int q_base_offset = bi * params.stride + hi * Dh;
    const int k_base_offset = bi * params.stride + hi_kv * Dh;
    const int v_base_offset = k_base_offset;

    const size_t bi_seq_len_offset = bi * params.memory_max_len;

    // int tlength = (DO_CROSS_ATTENTION)? params.memory_length_per_sample[bi] - 1 : params.timestep;
    int tlength = (DO_CROSS_ATTENTION) ? params.memory_length_per_sample[bi] - 1 :
                  (params.length_per_sample == nullptr) ?
                      params.timestep :
                      params.length_per_sample[bi] + params.max_prefix_prompt_length;

    const int first_step   = max(0, tlength + 1 - params.memory_max_len);
    const int tlength_circ = tlength % params.memory_max_len;
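The two lines above implement the circular KV cache: once `tlength` exceeds `memory_max_len`, the oldest entries are dropped and each timestep maps to slot `ti % memory_max_len`. A small standalone sketch with made-up sizes (not from the file) shows the mapping:

```cpp
// Sketch of the circular-buffer arithmetic above (illustrative values only).
#include <algorithm>
#include <cstdio>

int main() {
    const int memory_max_len = 8;   // cache capacity (illustrative)
    const int tlength        = 11;  // current logical timestep (illustrative)

    const int first_step   = std::max(0, tlength + 1 - memory_max_len);
    const int tlength_circ = tlength % memory_max_len;
    std::printf("first_step=%d tlength_circ=%d\n", first_step, tlength_circ);

    // Only timesteps in [first_step, tlength] are still resident; each lands in slot ti % memory_max_len.
    for (int ti = first_step; ti <= tlength; ++ti) {
        std::printf("ti=%d -> slot %d\n", ti, ti % memory_max_len);
    }
    return 0;
}
```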
    // First QK_VECS_PER_WARP load Q and K + the bias values for the current timestep.
    const bool is_masked = tidx >= QK_VECS_PER_WARP;

    // The offset in the Q and K buffer also accounts for the batch.
    // int qk_offset = qkv_base_offset + tidx * QK_VEC_SIZE;
    int q_offset = q_base_offset + tidx * QK_VEC_SIZE;
    int k_offset = k_base_offset + tidx * QK_VEC_SIZE;
    int v_offset = k_offset;
    // The offset in the bias buffer.
    // int qk_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
    int q_bias_offset = hi * Dh + tidx * QK_VEC_SIZE;
    int k_bias_offset = hi_kv * Dh + tidx * QK_VEC_SIZE;
    int v_bias_offset = k_bias_offset;

    const bool do_ia3       = handle_kv && params.ia3_tasks != nullptr;
    const int  ia3_task_id  = do_ia3 ? params.ia3_tasks[bbi] : 0;
    // Trigger the loads from the Q and K buffers.
    Qk_vec q;
    zero(q);
    if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
        if (params.int8_mode == 2) {
            using Packed_Int8_t  = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
            using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
            const auto q_scaling = params.qkv_scale_out[0];
            const auto q_quant =
                *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.q)[q_offset]);

            convert_from_float(q, mul<Packed_Float_t, float>(q_scaling, float_from_int8(q_quant)));
        }
        else {
            q = *reinterpret_cast<const Qk_vec*>(&params.q[q_offset]);
        }
    }

    Qk_vec k;
    zero(k);
    if (DO_CROSS_ATTENTION) {
        // The 16B chunk written by the thread.
        int co = tidx / QK_VECS_IN_16B;
        // The position of the thread in that 16B chunk.
        int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;

        // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
        int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                     // params.timestep*QK_ELTS_IN_16B +
                     tlength * QK_ELTS_IN_16B + ci;
        k = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) ?
                *reinterpret_cast<const Qk_vec*>(&params.k_cache[offset]) :
                k;
    }
    else {
        if (!is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh)) {
            if (params.int8_mode == 2) {
                using Packed_Int8_t  = typename packed_type<int8_t, num_elems<Qk_vec>::value>::type;
                using Packed_Float_t = typename packed_type<float, num_elems<Qk_vec>::value>::type;
                const auto k_scaling = params.qkv_scale_out[1];
                const auto k_quant =
                    *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.k)[k_offset]);

                convert_from_float(k, mul<Packed_Float_t, float>(k_scaling, float_from_int8(k_quant)));
            }
            else {
                k = *reinterpret_cast<const Qk_vec*>(&params.k[k_offset]);
            }
        }
    }
    // Trigger the loads from the Q and K bias buffers.
    Qk_vec q_bias;
    zero(q_bias);
    q_bias = (!is_masked && Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.q_bias != nullptr ?
                 *reinterpret_cast<const Qk_vec*>(&params.q_bias[q_bias_offset]) :
                 q_bias;

    Qk_vec k_bias;
    zero(k_bias);
    if (handle_kv) {
        k_bias = !is_masked && (Dh == Dh_MAX || tidx * QK_VEC_SIZE < Dh) && params.k_bias != nullptr ?
                     *reinterpret_cast<const Qk_vec*>(&params.k_bias[k_bias_offset]) :
                     k_bias;
    }

    // Computes the Q/K values with bias.
    q = add(q, q_bias);
    if (handle_kv) {
        k = add(k, k_bias);
    }
    if (do_ia3 && !is_masked) {
        k = mul<Qk_vec, Qk_vec, Qk_vec>(
            k,
            *reinterpret_cast<const Qk_vec*>(
                &params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + tidx * QK_VEC_SIZE]));
    }
    // Padded len
    const int padd_len = (params.total_padding_tokens == nullptr) ? 0 : params.total_padding_tokens[bi];
    if (params.rotary_embedding_dim > 0 && !params.neox_rotary_style) {
        if (handle_kv) {
            apply_rotary_embedding(q, k, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
        }
        else {
            apply_rotary_embedding(q, tidx, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);
        }
    }
    else if (params.rotary_embedding_dim > 0 && params.neox_rotary_style) {
        const bool do_rotary = !is_masked && QK_VEC_SIZE * tidx < params.rotary_embedding_dim;

        T* q_smem = reinterpret_cast<T*>(smem_);
        T* k_smem = q_smem + params.rotary_embedding_dim;

        const int half_rotary_dim = params.rotary_embedding_dim / 2;
        const int half_idx        = (tidx * QK_VEC_SIZE) / half_rotary_dim;
        const int intra_half_idx  = (tidx * QK_VEC_SIZE) % half_rotary_dim;
        const int smem_pitch      = half_rotary_dim;  // TODO: adjust for bank conflicts

        assert(half_rotary_dim % QK_VEC_SIZE == 0);

        if (do_rotary) {
            *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx) = q;

            if (handle_kv) {
                *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx) = k;
            }
        }

        __syncthreads();

        const int     transpose_idx = half_idx * (half_rotary_dim / 2) + intra_half_idx / 2;
        constexpr int tidx_factor   = (QK_VEC_SIZE > 1) ? QK_VEC_SIZE / 2 : 1;
        if (do_rotary) {
            mmha::vec_from_smem_transpose(q, q_smem, transpose_idx, smem_pitch);

            if (handle_kv) {
                mmha::vec_from_smem_transpose(k, k_smem, transpose_idx, smem_pitch);

                mmha::apply_rotary_embedding(
                    q, k, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength - padd_len, params.rotary_base);

                mmha::write_smem_transpose(k, k_smem, transpose_idx, smem_pitch);
            }
            else {
                mmha::apply_rotary_embedding(
                    q, transpose_idx / tidx_factor, params.rotary_embedding_dim, tlength, params.rotary_base);
            }
            mmha::write_smem_transpose(q, q_smem, transpose_idx, smem_pitch);
        }

        __syncthreads();

        if (do_rotary) {
            q = *reinterpret_cast<Qk_vec*>(q_smem + half_idx * smem_pitch + intra_half_idx);

            if (handle_kv) {
                k = *reinterpret_cast<Qk_vec*>(k_smem + half_idx * smem_pitch + intra_half_idx);
            }
        }

        __syncthreads();
    }
    if (!is_masked) {
        // Store the Q values to shared memory.
        *reinterpret_cast<Qk_vec*>(&q_smem[tidx * QK_VEC_SIZE]) = q;

        // Store Dh values of k_bias into smem, since will need to add later
        // if params.timestep == 0
        if (DO_CROSS_ATTENTION && params.timestep == 0) {
            *reinterpret_cast<Qk_vec*>(&bias_smem[tidx * QK_VEC_SIZE]) = k_bias;
        }

        // Write the K values to the global memory cache.
        //
        // NOTE: The stores are uncoalesced as we have multiple chunks of 16B spread across the memory
        // system. We designed it this way as it allows much better memory loads (and there are many
        // more loads) + the stores are really "write and forget" since we won't need the ack before
        // the end of the kernel. There's plenty of time for the transactions to complete.

        // The 16B chunk written by the thread.
        int co = tidx / QK_VECS_IN_16B;
        // The position of the thread in that 16B chunk.
        int ci = tidx % QK_VECS_IN_16B * QK_VEC_SIZE;

        // Two chunks are separated by L * x elements. A thread write QK_VEC_SIZE elements.
        int offset = bhi_kv * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B +
                     // params.timestep*QK_ELTS_IN_16B +
                     tlength_circ * QK_ELTS_IN_16B + ci;

        if (write_kv_cache) {
            // Trigger the stores to global memory.
            if (Dh == Dh_MAX || co < Dh / QK_ELTS_IN_16B) {
                *reinterpret_cast<Qk_vec*>(&params.k_cache[offset]) = k;
            }
        }

        // Compute \sum_i Q[i] * K^T[i] for the current timestep.
#ifdef MMHA_USE_FP32_ACUM_FOR_FMA
        using Qk_vec_acum = typename Qk_vec_acum_fp32_<Qk_vec>::Type;
#else
        using Qk_vec_acum = Qk_vec;
#endif
        qk = dot<Qk_vec_acum, Qk_vec>(q, k);
        if (QK_VECS_PER_WARP <= WARP_SIZE) {
#pragma unroll
            for (int mask = QK_VECS_PER_WARP / 2; mask >= 1; mask /= 2) {
                qk += __shfl_xor_sync(shfl_mask(QK_VECS_PER_WARP), qk, mask);
            }
        }
    }

    if (QK_VECS_PER_WARP > WARP_SIZE) {
        constexpr int WARPS_PER_RED = (QK_VECS_PER_WARP + WARP_SIZE - 1) / WARP_SIZE;
        qk = block_sum<WARPS_PER_RED>(&red_smem[WARPS_PER_RED], qk);
    }

    // Store that value in shared memory. Keep the Q*K^T value in register for softmax.
    if (tidx == 0) {
        // Normalize qk.
        qk *= params.inv_sqrt_dh;
        if (params.relative_attention_bias != nullptr) {
            // TODO (Haotian): check whether we should replace hi with hi_kv,
            // although params.relative_attention_bias is usually not used.
            qk = add(qk,
                     params.relative_attention_bias[hi * params.relative_attention_bias_stride
                                                        * params.relative_attention_bias_stride
                                                    + (tlength - padd_len) * params.relative_attention_bias_stride
                                                    + (tlength - padd_len)]);
        }
        // Add alibi positional encoding
        // qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0;
        // We don't need to apply the linear position bias here since qi - ki = 0 yields the position bias 0.

        qk_max                        = qk;
        qk_smem[tlength - first_step] = qk;
        // qk_smem[params.timestep] = qk;
    }

    // Make sure the data is in shared memory.
    __syncthreads();
    // The type of queries and keys for the math in the Q*K^T product.
    using K_vec = typename K_vec_<T, THREADS_PER_KEY>::Type;
    // The number of elements per vector.
    constexpr int K_VEC_SIZE = sizeof(K_vec) / sizeof(T);
    // Make sure the hidden size per head is a multiple of the vector size.
    static_assert(Dh_MAX % K_VEC_SIZE == 0, "");
    // The number of elements per thread.
    constexpr int K_ELTS_PER_THREAD = Dh_MAX / THREADS_PER_KEY;
    // The number of vectors per thread.
    constexpr int K_VECS_PER_THREAD = K_ELTS_PER_THREAD / K_VEC_SIZE;

    // The position the first key loaded by each thread from the cache buffer (for this B * H).
    int ko = tidx / THREADS_PER_KEY;
    // The position of the thread in the chunk of keys.
    int ki = tidx % THREADS_PER_KEY * K_VEC_SIZE;

    static_assert(Dh_MAX == THREADS_PER_KEY * K_VEC_SIZE * K_VECS_PER_THREAD);

    // Load the Q values from shared memory. The values are reused during the loop on K.
    K_vec q_vec[K_VECS_PER_THREAD];
#pragma unroll
    for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
        q_vec[ii] = *reinterpret_cast<const K_vec*>(&q_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
    }

    K_vec k_bias_vec[DO_CROSS_ATTENTION ? K_VECS_PER_THREAD : 1];
    if (DO_CROSS_ATTENTION && params.timestep == 0) {
#pragma unroll
        for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
            k_bias_vec[ii] = *reinterpret_cast<const K_vec*>(&bias_smem[ki + ii * THREADS_PER_KEY * K_VEC_SIZE]);
        }
    }

    // The number of timesteps loaded per iteration.
    constexpr int K_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_KEY;
    // The number of keys per warp.
    constexpr int K_PER_WARP = WARP_SIZE / THREADS_PER_KEY;

    // The base pointer for the key in the cache buffer.
    T* k_cache = &params.k_cache[bhi_kv * params.memory_max_len * Dh + ki];
    // Base pointer for the beam's batch, before offsetting with indirection buffer
    T* k_cache_batch = &params.k_cache[bbhi_kv * params.memory_max_len * Dh + ki];

    // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
    // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
    int ti_end = div_up(tlength - first_step, K_PER_WARP) * K_PER_WARP + first_step;

    // prefix prompt length if has
    const int prefix_prompt_length = (params.prefix_prompt_lengths == nullptr) ? 0 : params.prefix_prompt_lengths[bi];

    // Iterate over the keys/timesteps to compute the various (Q*K^T)_{ti} values.
    const bool has_beams    = params.cache_indir != nullptr;
    const int* beam_indices = has_beams ? &params.cache_indir[bi_seq_len_offset] : nullptr;
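The `bhi_kv` / `hi_kv` indexing used for the cache pointers above is the grouped-query-attention mapping introduced in this file: several query heads share one K/V head, and only the first head of each group writes the cache back. A small host-side sketch (head counts are illustrative, not from the file) shows the mapping:

```cpp
// Sketch of the grouped-query-attention head mapping used for the K/V cache above.
#include <cstdio>

int main() {
    const int num_heads    = 32;  // query heads (illustrative)
    const int num_kv_heads = 8;   // key/value heads (illustrative)
    const int kv_rep       = num_heads / num_kv_heads;

    for (int hi = 0; hi < num_heads; ++hi) {
        const int  hi_kv          = hi / kv_rep;         // K/V head shared by this query head
        const bool write_kv_cache = (hi % kv_rep == 0);  // only one head per group writes back
        std::printf("hi=%2d -> hi_kv=%d write=%d\n", hi, hi_kv, write_kv_cache);
    }
    return 0;
}
```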
    for (int ti = first_step + ko; ti < ti_end; ti += K_PER_ITER) {
        const int ti_circ = ti % params.memory_max_len;

        // The keys loaded from the key cache.
        K_vec k[K_VECS_PER_THREAD];
        K_vec k_vec_zero;
        zero(k_vec_zero);
#pragma unroll
        for (int ii = 0; ii < K_VECS_PER_THREAD; ++ii) {
            int jj = ii * params.memory_max_len + ti_circ;
            // if( ti < params.timestep ) {
            const bool within_bounds = (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len);
            if (ti < tlength) {
                if (!within_bounds) {
                    k[ii] = k_vec_zero;
                }
                else {
                    if (has_beams) {
                        const int beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
                        k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B]);
                    }
                    else {
                        k[ii] = *reinterpret_cast<const K_vec*>(&k_cache_batch[jj * QK_ELTS_IN_16B]);
                    }
                }

                // add bias and update k_cache
                if (DO_CROSS_ATTENTION && params.timestep == 0) {
                    k[ii] = add(k[ii], k_bias_vec[ii]);

                    if (do_ia3) {
                        k[ii] = mul<K_vec, K_vec, K_vec>(
                            k[ii],
                            *reinterpret_cast<const K_vec*>(
                                &params.ia3_key_weights[(ia3_task_id * params.num_heads + hi) * Dh + ki
                                                        + ii * THREADS_PER_KEY * K_VEC_SIZE]));
                    }

                    if (Dh == Dh_MAX || jj * QK_ELTS_IN_16B < Dh * params.memory_max_len) {
                        *reinterpret_cast<K_vec*>(&k_cache[jj * QK_ELTS_IN_16B]) = k[ii];
                    }
                }
            }
        }

        // Perform the dot product and normalize qk.
        //
        // WARNING: ALL THE THREADS OF A WARP MUST ENTER!!!
        float qk      = Qk_dot<T, THREADS_PER_KEY>::dot(q_vec, k) * params.inv_sqrt_dh;
        bool  is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];

        // Store the product to shared memory. There's one qk value per timestep. Update the max.
        // if( ti < params.timestep && tidx % THREADS_PER_KEY == 0 ) {
        if (ti < tlength && tidx % THREADS_PER_KEY == 0) {
            if (params.relative_attention_bias != nullptr) {
                qk = add(qk,
                         params.relative_attention_bias[hi * params.relative_attention_bias_stride
                                                            * params.relative_attention_bias_stride
                                                        + tlength * params.relative_attention_bias_stride + ti]);
            }
            if (params.linear_bias_slopes != nullptr) {
                // Apply the linear position bias: (ki - qi) * slope[hi].
                // The padding token locates between the input context and the generated tokens.
                // We need to remove the number of padding tokens in the distance computation.
                //   ti   : 0 1 2 3 4 5 6 7 8 9(tlength)
                //   token: i i i i p p p o o o where i=input, p=pad, o=output.
                // e.g. ti = 2, dist = (9 - 3) - 2 = 4.
                int   max_context_length = params.max_prefix_prompt_length + params.max_input_length;
                float dist               = (ti < max_context_length ? ti + padd_len : ti) - tlength;

                qk += mul<float, float, float>(params.linear_bias_slopes[hi], dist);
            }
            // Add alibi positional encoding
            // qk += (alibi_slope != 0) ? alibi_slope * (params.timestep - params.memory_max_len) : 0;
            qk_max                   = is_mask ? qk_max : fmaxf(qk_max, qk);
            qk_smem[ti - first_step] = qk;
        }
    }
    // Perform the final reduction to compute the max inside each warp.
    //
    // NOTE: In a group of THREADS_PER_KEY threads, the leader already has the max value for the
    // group so it's not needed to run the reduction inside the group (again).
#pragma unroll
    for (int mask = WARP_SIZE / 2; mask >= THREADS_PER_KEY; mask /= 2) {
        qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
    }

    // Decompose the thread index into warp and lane.
    const int warp = tidx / WARP_SIZE;
    const int lane = tidx % WARP_SIZE;

    // The warp leader writes the max to shared memory.
    if (lane == 0) {
        red_smem[warp] = qk_max;
    }

    // Make sure the products are in shared memory.
    __syncthreads();

    // The warps finalize the reduction.
    qk_max = lane < WARPS_PER_BLOCK ? red_smem[lane] : -FLT_MAX;
#pragma unroll
    for (int mask = WARPS_PER_BLOCK / 2; mask >= 1; mask /= 2) {
        qk_max = fmaxf(qk_max, __shfl_xor_sync(uint32_t(-1), qk_max, mask));
    }

    // Broadcast to all the threads in the warp.
    qk_max = __shfl_sync(uint32_t(-1), qk_max, 0);

    // Compute the logits and start the sum.
    float sum = 0.f;
    // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
    for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
        bool  is_mask = (params.masked_tokens != nullptr) && params.masked_tokens[bi_seq_len_offset + ti];
        float logit   = is_mask ? 0.f : __expf(qk_smem[ti - first_step] - qk_max);
        sum += logit;
        qk_smem[ti - first_step] = logit;
    }

    // Compute the sum.
    sum = block_sum<WARPS_PER_BLOCK>(&red_smem[WARPS_PER_BLOCK], sum);

    // Normalize the logits.
    float inv_sum = __fdividef(1.f, sum + 1.e-6f);
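The block above is the standard numerically stable softmax: subtract the running maximum before exponentiating, accumulate the sum, then normalize with a small epsilon. A CPU-side sketch with made-up scores (not taken from the kernel) mirrors the same three steps:

```cpp
// CPU sketch of the max-subtracted softmax applied to qk_smem above (illustrative scores).
#include <cmath>
#include <cstdio>
#include <vector>

int main() {
    std::vector<float> qk = {2.0f, 5.0f, 3.0f, 4.0f};  // illustrative Q*K^T scores

    float qk_max = -INFINITY;
    for (float v : qk) qk_max = std::fmax(qk_max, v);

    float sum = 0.f;
    for (float& v : qk) { v = std::exp(v - qk_max); sum += v; }  // subtract max before exp

    const float inv_sum = 1.f / (sum + 1.e-6f);                  // same epsilon as the kernel
    for (float& v : qk) v *= inv_sum;

    for (float v : qk) std::printf("%f\n", v);
    return 0;
}
```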
    // for( int ti = tidx; ti <= params.timestep; ti += THREADS_PER_BLOCK ) {
    const size_t cross_attention_out_offset =
        params.is_return_cross_attentions ?
            bhi_kv * params.max_decoder_seq_len * params.memory_max_len + params.timestep * params.memory_max_len :
            0;
    for (int ti = first_step + tidx; ti <= tlength; ti += THREADS_PER_BLOCK) {
        float logit = qk_smem[ti - first_step] * inv_sum;
        if (params.is_return_cross_attentions) {
            params.cross_attention_out[cross_attention_out_offset + ti] = logit;
        }
        convert_from_float(logits_smem[ti - first_step], logit);
    }
    // Put Values part below so we leverage __syncthreads
    // from the previous step

    // The number of elements per vector.
    constexpr int V_VEC_SIZE = Dh_MAX / THREADS_PER_VALUE;
    // A vector of V elements for the current timestep.
    using V_vec = typename V_vec_<T, V_VEC_SIZE>::Type;

    // The value computed by this thread.
    int vo = tidx / THREADS_PER_VALUE;
    // The hidden dimensions computed by this particular thread.
    int vi = tidx % THREADS_PER_VALUE * V_VEC_SIZE;

    // The base pointer for the value in the cache buffer.
    T* v_cache = &params.v_cache[bhi_kv * params.memory_max_len * Dh + vi];
    // Base pointer for the beam's batch, before offsetting with indirection buffer
    T* v_cache_batch = &params.v_cache[bbhi_kv * params.memory_max_len * Dh + vi];

    // The number of values processed per iteration of the loop.
    constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;

    // One group of threads computes the product(s) for the current timestep.
    V_vec v_bias;
    zero(v_bias);
    // if( vo == params.timestep % V_PER_ITER ) {
    if (Dh == Dh_MAX || vi < Dh) {
        if (handle_kv) {
            if (vo == tlength % V_PER_ITER) {
                // Trigger the loads from the V bias buffer.
                if (params.v_bias != nullptr) {
                    v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi_kv * Dh + vi]);
                }
                if (DO_CROSS_ATTENTION) {
                    *reinterpret_cast<V_vec*>(&bias_smem[vi]) = v_bias;
                }
            }
        }
    }

    // From previous, before values, step
    // Also make sure the logits are in shared memory.
    __syncthreads();
    // Values continued
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
    using V_vec_acum = typename V_vec_acum_fp32_<V_vec>::Type;
#else
    using V_vec_acum = V_vec;
#endif
    // The partial outputs computed by each thread.
    V_vec_acum out;
    zero(out);

    // Loop over the timesteps to compute the partial outputs.
    // for( int ti = vo; ti < params.timestep; ti += V_PER_ITER ) {
    if (Dh == Dh_MAX || vi < Dh) {
        for (int ti = first_step + vo; ti < tlength; ti += V_PER_ITER) {
            const int ti_circ = ti % params.memory_max_len;

            // Fetch offset based on cache_indir when beam sampling
            const int beam_src    = (params.cache_indir != nullptr) ? params.cache_indir[bi_seq_len_offset + ti_circ] : 0;
            const int beam_offset = beam_src * params.num_heads * params.memory_max_len * Dh;
            // Load the values from the cache.
            V_vec v = *reinterpret_cast<const V_vec*>(&v_cache_batch[beam_offset + ti_circ * Dh]);
            if (DO_CROSS_ATTENTION && params.timestep == 0) {
                v = add(v, *reinterpret_cast<V_vec*>(&bias_smem[vi]));
                if (do_ia3) {
                    v = mul<V_vec, V_vec, V_vec>(
                        v,
                        *reinterpret_cast<const V_vec*>(
                            &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
                }
                *reinterpret_cast<V_vec*>(&v_cache[ti * Dh]) = v;
            }
            // Load the logits from shared memory.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
            float logit = logits_smem[ti - first_step];
            out         = fma(logit, cast_to_float(v), out);
#else
            T logit = logits_smem[ti - first_step];

            // Update the partial sums.
            out = fma(logit, v, out);
#endif
        }
    }
    // One group of threads computes the product(s) for the current timestep.
    // if( vo == params.timestep % V_PER_ITER ) {
    if (vo == tlength % V_PER_ITER && (Dh == Dh_MAX || vi < Dh)) {
        V_vec v;
        if (DO_CROSS_ATTENTION) {
            v = *reinterpret_cast<const V_vec*>(&v_cache[tlength * Dh]);
        }
        else {
            // Trigger the loads from the V buffer.
            const auto v_offset = v_base_offset + vi;
            if (params.int8_mode == 2) {
                using Packed_Int8_t  = typename packed_type<int8_t, num_elems<V_vec>::value>::type;
                using Packed_Float_t = typename packed_type<float, num_elems<V_vec>::value>::type;
                const auto v_scaling = params.qkv_scale_out[2];
                const auto v_quant =
                    *reinterpret_cast<const Packed_Int8_t*>(&reinterpret_cast<const int8_t*>(params.v)[v_offset]);

                convert_from_float(v, mul<Packed_Float_t, float>(v_scaling, float_from_int8(v_quant)));
            }
            else {
                v = *reinterpret_cast<const V_vec*>(&params.v[v_offset]);
            }
            // Trigger the loads from the V bias buffer.
            // V_vec v_bias = *reinterpret_cast<const V_vec*>(&params.v_bias[hi*Dh + vi]);
        }

        // Compute the V values with bias.
        v = add(v, v_bias);

        if (write_kv_cache) {
            if (do_ia3) {
                v = mul<V_vec, V_vec, V_vec>(
                    v,
                    *reinterpret_cast<const V_vec*>(
                        &params.ia3_value_weights[(ia3_task_id * params.num_heads + hi) * Dh + vi]));
            }

            // Store the values with bias back to global memory in the cache for V.
            //*reinterpret_cast<V_vec*>(&v_cache[params.timestep*Dh]) = v;
            *reinterpret_cast<V_vec*>(&v_cache[tlength_circ * Dh]) = v;
        }

        // Initialize the output value with the current timestep.
#if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
        // out = fma(logits_smem[params.timestep], cast_to_float(v), out);
        out = fma(logits_smem[tlength - first_step], cast_to_float(v), out);
#else
        // out = fma(logits_smem[params.timestep], v, out);
        out = fma(logits_smem[tlength - first_step], v, out);
#endif
    }
    // Make sure we can start writing to shared memory.
    __syncthreads();

    // Run the final reduction amongst the different groups computing different partial outputs.
    if (Dh == Dh_MAX || vi < Dh) {
#pragma unroll
        for (int active_groups = V_PER_ITER; active_groups >= 2; active_groups /= 2) {

            // The midpoint in the number of active groups.
            int midpoint = active_groups / 2;

            // The upper part of active threads store to shared memory.
            if (vo >= midpoint && vo < active_groups && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
                convert_from_float(*reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]), out);
#else
                *reinterpret_cast<V_vec*>(&out_smem[(vo - midpoint) * Dh + vi]) = out;
#endif
            }
            __syncthreads();

            // The bottom warps update their values.
            if (vo < midpoint && (Dh == Dh_MAX || vi < Dh)) {
                out = add(*reinterpret_cast<const V_vec*>(&out_smem[vo * Dh + vi]), out);
            }
            __syncthreads();
        }
    }
    // Output the final values.
    if (vo == 0 && (Dh == Dh_MAX || vi < Dh)) {
#ifdef MMHA_USE_FP32_ACUM_FOR_OUT
        if (params.int8_mode == 2) {
            using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_acum>::value>::type;
            out                 = mul<V_vec_acum, float>(*params.attention_out_scale, out);
            *reinterpret_cast<Packed_Int8_t*>(&(reinterpret_cast<int8_t*>(params.out)[bhi * Dh + vi])) =
                cast_to_int8(out);
        }
        else {
            convert_from_float(*reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]), out);
        }
#else
        // TODO: support int8_mode?
        *reinterpret_cast<V_vec*>(&params.out[bhi * Dh + vi]) = out;
#endif
    }
}

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace mmha

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int Dh, int Dh_MAX, typename KERNEL_PARAMS_TYPE>
void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& stream);
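Several of the reductions in the kernel above (the per-warp `qk_max` and the `block_sum` helper) are built on the `__shfl_xor_sync` butterfly pattern. The following is a small standalone CUDA sketch of that warp-level pattern only, with illustrative names and data; it is not code from this file.

```cuda
// Standalone sketch of the __shfl_xor_sync butterfly reduction used for qk_max above.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void warp_max_kernel(const float* in, float* out) {
    float v = in[threadIdx.x];
    // After log2(32) steps, every lane of the warp holds the maximum.
    for (int mask = 16; mask >= 1; mask /= 2) {
        v = fmaxf(v, __shfl_xor_sync(0xffffffffu, v, mask));
    }
    if (threadIdx.x == 0) {
        *out = v;
    }
}

int main() {
    float h_in[32], h_out, *d_in, *d_out;
    for (int i = 0; i < 32; ++i) h_in[i] = static_cast<float>((i * 7) % 13);
    cudaMalloc(&d_in, sizeof(h_in));
    cudaMalloc(&d_out, sizeof(float));
    cudaMemcpy(d_in, h_in, sizeof(h_in), cudaMemcpyHostToDevice);
    warp_max_kernel<<<1, 32>>>(d_in, d_out);
    cudaMemcpy(&h_out, d_out, sizeof(float), cudaMemcpyDeviceToHost);
    std::printf("warp max = %f\n", h_out);  // expect 12
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}
```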
awq_cuda/attention/decoder_masked_multihead_attention_utils.h
deleted 100644 → 0
// Downloaded from FasterTransformer v5.2.1
// https://github.com/NVIDIA/FasterTransformer/blob/release/v5.2.1_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention_utils.h
/*
 * Copyright (c) 2020-2022, NVIDIA CORPORATION. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include "cuda_bf16_wrapper.h"
#include "cuda_bf16_fallbacks.cuh"
#include <stdint.h>

using namespace fastertransformer;

namespace mmha {

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Float8_ {
    float2 x;
    float2 y;
    float2 z;
    float2 w;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct Float4_ {
    float2 x;
    float2 y;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

#ifdef ENABLE_BF16
struct bf16_4_t {
    __nv_bfloat162 x;
    __nv_bfloat162 y;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

struct bf16_8_t {
    __nv_bfloat162 x;
    __nv_bfloat162 y;
    __nv_bfloat162 z;
    __nv_bfloat162 w;
};
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T>
struct num_elems;
template<>
struct num_elems<float> {
    static constexpr int value = 1;
};
template<>
struct num_elems<float2> {
    static constexpr int value = 2;
};
template<>
struct num_elems<float4> {
    static constexpr int value = 4;
};
template<>
struct num_elems<Float4_> {
    static constexpr int value = 4;
};
template<>
struct num_elems<Float8_> {
    static constexpr int value = 8;
};
template<>
struct num_elems<uint32_t> {
    static constexpr int value = 2;
};
template<>
struct num_elems<uint2> {
    static constexpr int value = 4;
};
template<>
struct num_elems<uint4> {
    static constexpr int value = 8;
};
#ifdef ENABLE_BF16
template<>
struct num_elems<__nv_bfloat162> {
    static constexpr int value = 2;
};
template<>
struct num_elems<bf16_4_t> {
    static constexpr int value = 4;
};
template<>
struct num_elems<bf16_8_t> {
    static constexpr int value = 8;
};
#endif

////////////////////////////////////////////////////////////////////////////////////////////////////

template<typename T, int N>
struct packed_type;
template<typename T>
struct packed_type<T, 1> {
    using type = T;
};
template<>
struct packed_type<int8_t, 2> {
    using type = int16_t;
};
template<>
struct packed_type<int8_t, 4> {
    using type = int32_t;
};
template<>
struct packed_type<int8_t, 8> {
    using type = int64_t;
};
template<>
struct packed_type<float, 2> {
    using type = float2;
};
template<>
struct packed_type<float, 4> {
    using type = float4;
};
template<>
struct packed_type<float, 8> {
    using type = Float8_;
};

////////////////////////////////////////////////////////////////////////////////////////////////////
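Together, `num_elems<V>::value` gives the lane count of a vector type and `packed_type<T, N>::type` gives a storage word that carries N lanes of T; the kernel's `int8_mode == 2` paths combine the two to pick the right quantized load width. A minimal standalone sketch (the trait copies below are reduced re-declarations for illustration, not the originals) checks the int8 resolutions at compile time:

```cpp
// Compile-time sketch of what packed_type resolves to on the int8 path.
#include <cstdint>
#include <type_traits>

template<typename T, int N>
struct packed_type;
template<typename T>
struct packed_type<T, 1> { using type = T; };
template<> struct packed_type<int8_t, 2> { using type = int16_t; };
template<> struct packed_type<int8_t, 4> { using type = int32_t; };
template<> struct packed_type<int8_t, 8> { using type = int64_t; };

static_assert(std::is_same<packed_type<int8_t, 4>::type, int32_t>::value,
              "4 int8 lanes travel in one 32-bit word");
static_assert(sizeof(packed_type<int8_t, 8>::type) == 8,
              "8 int8 lanes travel in one 64-bit word");

int main() { return 0; }
```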
inline __device__ float add(float a, float b)
{
    return a + b;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float2 add(float2 a, float2 b)
{
    float2 c;
    c.x = add(a.x, b.x);
    c.y = add(a.y, b.y);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////

inline __device__ float4 add(float4 a, float4 b)
{
    float4 c;
    c.x = add(a.x, b.x);
    c.y = add(a.y, b.y);
    c.z = add(a.z, b.z);
    c.w = add(a.w, b.w);
    return c;
}

////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
__nv_bfloat16
add
(
__nv_bfloat16
a
,
__nv_bfloat16
b
)
{
return
a
+
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
__nv_bfloat162
add
(
__nv_bfloat162
a
,
__nv_bfloat162
b
)
{
return
bf16hadd2
(
a
,
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_4_t
add
(
bf16_4_t
a
,
bf16_4_t
b
)
{
bf16_4_t
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_8_t
add
(
bf16_8_t
a
,
bf16_8_t
b
)
{
bf16_8_t
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
c
.
z
=
add
(
a
.
z
,
b
.
z
);
c
.
w
=
add
(
a
.
w
,
b
.
w
);
return
c
;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint16_t
add
(
uint16_t
a
,
uint16_t
b
)
{
uint16_t
c
;
asm
volatile
(
"add.f16 %0, %1, %2;
\n
"
:
"=h"
(
c
)
:
"h"
(
a
),
"h"
(
b
));
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
add
(
uint32_t
a
,
uint32_t
b
)
{
uint32_t
c
;
asm
volatile
(
"add.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
c
)
:
"r"
(
a
),
"r"
(
b
));
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint2
add
(
uint2
a
,
uint2
b
)
{
uint2
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint4
add
(
uint4
a
,
uint4
b
)
{
uint4
c
;
c
.
x
=
add
(
a
.
x
,
b
.
x
);
c
.
y
=
add
(
a
.
y
,
b
.
y
);
c
.
z
=
add
(
a
.
z
,
b
.
z
);
c
.
w
=
add
(
a
.
w
,
b
.
w
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint16_t
float_to_half
(
float
f
)
{
union
{
uint32_t
u32
;
uint16_t
u16
[
2
];
}
tmp
;
#if 0 && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800 // Is it better?
float zero = 0.f;
asm volatile("cvt.rn.f16x2.f32 %0, %1, %2;\n" : "=r"(tmp.u32) : "f"(zero), "f"(f));
#else
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;
\n
"
:
"=h"
(
tmp
.
u16
[
0
])
:
"f"
(
f
));
#endif
return
tmp
.
u16
[
0
];
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
float2_to_half2
(
float2
f
)
{
union
{
uint32_t
u32
;
uint16_t
u16
[
2
];
}
tmp
;
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
asm
volatile
(
"cvt.rn.f16x2.f32 %0, %1, %2;
\n
"
:
"=r"
(
tmp
.
u32
)
:
"f"
(
f
.
y
),
"f"
(
f
.
x
));
#else
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;
\n
"
:
"=h"
(
tmp
.
u16
[
0
])
:
"f"
(
f
.
x
));
asm
volatile
(
"cvt.rn.f16.f32 %0, %1;
\n
"
:
"=h"
(
tmp
.
u16
[
1
])
:
"f"
(
f
.
y
));
#endif
return
tmp
.
u32
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
half_to_float
(
uint16_t
h
)
{
float
f
;
asm
volatile
(
"cvt.f32.f16 %0, %1;
\n
"
:
"=f"
(
f
)
:
"h"
(
h
));
return
f
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
half2_to_float2
(
uint32_t
v
)
{
uint16_t
lo
,
hi
;
asm
volatile
(
"mov.b32 {%0, %1}, %2;
\n
"
:
"=h"
(
lo
),
"=h"
(
hi
)
:
"r"
(
v
));
return
make_float2
(
half_to_float
(
lo
),
half_to_float
(
hi
));
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
add
(
float
a
,
uint16_t
b
)
{
return
a
+
half_to_float
(
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
float
add
(
float
a
,
__nv_bfloat16
b
)
{
return
a
+
__bfloat162float
(
b
);
}
#endif
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
add
(
uint32_t
a
,
float2
fb
)
{
float2
fa
=
half2_to_float2
(
a
);
return
add
(
fa
,
fb
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
add
(
uint2
a
,
Float4_
fb
)
{
Float4_
fc
;
fc
.
x
=
add
(
a
.
x
,
fb
.
x
);
fc
.
y
=
add
(
a
.
y
,
fb
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
add
(
uint4
a
,
Float8_
fb
)
{
Float8_
fc
;
fc
.
x
=
add
(
a
.
x
,
fb
.
x
);
fc
.
y
=
add
(
a
.
y
,
fb
.
y
);
fc
.
z
=
add
(
a
.
z
,
fb
.
z
);
fc
.
w
=
add
(
a
.
w
,
fb
.
w
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
h0_h0
(
uint16_t
a
)
{
uint32_t
b
;
asm
volatile
(
"mov.b32 %0, {%1, %1};"
:
"=r"
(
b
)
:
"h"
(
a
));
return
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
fma
(
float
a
,
float
b
,
float
c
)
{
return
a
*
b
+
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
float2
a
,
float2
b
,
float2
c
)
{
float2
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
float
a
,
float2
b
,
float2
c
)
{
float2
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
fma
(
float4
a
,
float4
b
,
float4
c
)
{
float4
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
.
z
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
.
w
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float4
fma
(
float
a
,
float4
b
,
float4
c
)
{
float4
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
float
a
,
Float4_
b
,
Float4_
c
)
{
Float4_
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
float
a
,
Float8_
b
,
Float8_
c
)
{
Float8_
d
;
d
.
x
=
fma
(
a
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
float2
add
(
__nv_bfloat162
a
,
float2
fb
)
{
float2
fa
=
bf1622float2
(
a
);
return
add
(
fa
,
fb
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
add
(
bf16_4_t
a
,
Float4_
fb
)
{
Float4_
fc
;
fc
.
x
=
add
(
a
.
x
,
fb
.
x
);
fc
.
y
=
add
(
a
.
y
,
fb
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
add
(
bf16_8_t
a
,
Float8_
fb
)
{
Float8_
fc
;
fc
.
x
=
add
(
a
.
x
,
fb
.
x
);
fc
.
y
=
add
(
a
.
y
,
fb
.
y
);
fc
.
z
=
add
(
a
.
z
,
fb
.
z
);
fc
.
w
=
add
(
a
.
w
,
fb
.
w
);
return
fc
;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
fma
(
uint32_t
a
,
uint32_t
b
,
uint32_t
c
)
{
uint32_t
d
;
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
d
)
:
"r"
(
a
),
"r"
(
b
),
"r"
(
c
));
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint32_t
fma
(
uint16_t
a
,
uint32_t
b
,
uint32_t
c
)
{
return
fma
(
h0_h0
(
a
),
b
,
c
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint2
fma
(
uint2
a
,
uint2
b
,
uint2
c
)
{
uint2
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint2
fma
(
uint16_t
a
,
uint2
b
,
uint2
c
)
{
uint32_t
s
=
h0_h0
(
a
);
uint2
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint4
fma
(
uint4
a
,
uint4
b
,
uint4
c
)
{
uint4
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
.
z
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
.
w
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
uint4
fma
(
uint16_t
a
,
uint4
b
,
uint4
c
)
{
uint32_t
s
=
h0_h0
(
a
);
uint4
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
s
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
s
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
fma
(
uint16_t
a
,
uint16_t
b
,
float
fc
)
{
float
fa
=
half_to_float
(
a
);
float
fb
=
half_to_float
(
b
);
return
fa
*
fb
+
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
uint32_t
a
,
uint32_t
b
,
float2
fc
)
{
float2
fa
=
half2_to_float2
(
a
);
float2
fb
=
half2_to_float2
(
b
);
return
fma
(
fa
,
fb
,
fc
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
uint16_t
a
,
uint32_t
b
,
float2
fc
)
{
return
fma
(
h0_h0
(
a
),
b
,
fc
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
uint2
a
,
uint2
b
,
Float4_
fc
)
{
Float4_
fd
;
fd
.
x
=
fma
(
a
.
x
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
a
.
y
,
b
.
y
,
fc
.
y
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
uint16_t
a
,
uint2
b
,
Float4_
fc
)
{
uint32_t
s
=
h0_h0
(
a
);
Float4_
fd
;
fd
.
x
=
fma
(
s
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
s
,
b
.
y
,
fc
.
y
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
uint4
a
,
uint4
b
,
Float8_
fc
)
{
Float8_
fd
;
fd
.
x
=
fma
(
a
.
x
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
a
.
y
,
b
.
y
,
fc
.
y
);
fd
.
z
=
fma
(
a
.
z
,
b
.
z
,
fc
.
z
);
fd
.
w
=
fma
(
a
.
w
,
b
.
w
,
fc
.
w
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
uint16_t
a
,
uint4
b
,
Float8_
fc
)
{
uint32_t
s
=
h0_h0
(
a
);
Float8_
fd
;
fd
.
x
=
fma
(
s
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
s
,
b
.
y
,
fc
.
y
);
fd
.
z
=
fma
(
s
,
b
.
z
,
fc
.
z
);
fd
.
w
=
fma
(
s
,
b
.
w
,
fc
.
w
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
inline
__device__
__nv_bfloat162
fma
(
__nv_bfloat162
a
,
__nv_bfloat162
b
,
__nv_bfloat162
c
)
{
return
bf16hfma2
(
a
,
b
,
c
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
__nv_bfloat162
fma
(
__nv_bfloat16
a
,
__nv_bfloat162
b
,
__nv_bfloat162
c
)
{
return
bf16hfma2
(
bf162bf162
(
a
),
b
,
c
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_4_t
fma
(
bf16_4_t
a
,
bf16_4_t
b
,
bf16_4_t
c
)
{
bf16_4_t
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_4_t
fma
(
__nv_bfloat16
a
,
bf16_4_t
b
,
bf16_4_t
c
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
bf16_4_t
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_8_t
fma
(
bf16_8_t
a
,
bf16_8_t
b
,
bf16_8_t
c
)
{
bf16_8_t
d
;
d
.
x
=
fma
(
a
.
x
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
a
.
y
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
a
.
z
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
a
.
w
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
bf16_8_t
fma
(
__nv_bfloat16
a
,
bf16_8_t
b
,
bf16_8_t
c
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
bf16_8_t
d
;
d
.
x
=
fma
(
s
,
b
.
x
,
c
.
x
);
d
.
y
=
fma
(
s
,
b
.
y
,
c
.
y
);
d
.
z
=
fma
(
s
,
b
.
z
,
c
.
z
);
d
.
w
=
fma
(
s
,
b
.
w
,
c
.
w
);
return
d
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float
fma
(
__nv_bfloat16
a
,
__nv_bfloat16
b
,
float
fc
)
{
return
__bfloat162float
(
a
)
*
__bfloat162float
(
b
)
+
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
__nv_bfloat162
a
,
__nv_bfloat162
b
,
float2
fc
)
{
float2
fa
=
bf1622float2
(
a
);
float2
fb
=
bf1622float2
(
b
);
return
fma
(
fa
,
fb
,
fc
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
float2
fma
(
__nv_bfloat16
a
,
__nv_bfloat162
b
,
float2
fc
)
{
return
fma
(
bf162bf162
(
a
),
b
,
fc
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
bf16_4_t
a
,
bf16_4_t
b
,
Float4_
fc
)
{
Float4_
fd
;
fd
.
x
=
fma
(
a
.
x
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
a
.
y
,
b
.
y
,
fc
.
y
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float4_
fma
(
__nv_bfloat16
a
,
bf16_4_t
b
,
Float4_
fc
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
Float4_
fd
;
fd
.
x
=
fma
(
s
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
s
,
b
.
y
,
fc
.
y
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
bf16_8_t
a
,
bf16_8_t
b
,
Float8_
fc
)
{
Float8_
fd
;
fd
.
x
=
fma
(
a
.
x
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
a
.
y
,
b
.
y
,
fc
.
y
);
fd
.
z
=
fma
(
a
.
z
,
b
.
z
,
fc
.
z
);
fd
.
w
=
fma
(
a
.
w
,
b
.
w
,
fc
.
w
);
return
fd
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
inline
__device__
Float8_
fma
(
__nv_bfloat16
a
,
bf16_8_t
b
,
Float8_
fc
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
Float8_
fd
;
fd
.
x
=
fma
(
s
,
b
.
x
,
fc
.
x
);
fd
.
y
=
fma
(
s
,
b
.
y
,
fc
.
y
);
fd
.
z
=
fma
(
s
,
b
.
z
,
fc
.
z
);
fd
.
w
=
fma
(
s
,
b
.
w
,
fc
.
w
);
return
fd
;
}
#endif // ENABLE_BF16
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Acc
,
typename
A
,
typename
B
>
inline
__device__
Acc
mul
(
A
a
,
B
b
)
{
return
a
*
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float
mul
<
float
,
float
>
(
float
a
,
float
b
)
{
return
a
*
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
float2
a
,
float2
b
)
{
float2
c
;
c
.
x
=
a
.
x
*
b
.
x
;
c
.
y
=
a
.
y
*
b
.
y
;
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
float
a
,
float2
b
)
{
float2
c
;
c
.
x
=
a
*
b
.
x
;
c
.
y
=
a
*
b
.
y
;
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float4
mul
(
float4
a
,
float4
b
)
{
float4
c
;
c
.
x
=
a
.
x
*
b
.
x
;
c
.
y
=
a
.
y
*
b
.
y
;
c
.
z
=
a
.
z
*
b
.
z
;
c
.
w
=
a
.
w
*
b
.
w
;
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float4
mul
(
float
a
,
float4
b
)
{
float4
c
;
c
.
x
=
a
*
b
.
x
;
c
.
y
=
a
*
b
.
y
;
c
.
z
=
a
*
b
.
z
;
c
.
w
=
a
*
b
.
w
;
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float8_
mul
(
float
a
,
Float8_
b
)
{
Float8_
c
;
c
.
x
=
make_float2
(
a
*
b
.
x
.
x
,
a
*
b
.
x
.
y
);
c
.
y
=
make_float2
(
a
*
b
.
y
.
x
,
a
*
b
.
y
.
y
);
c
.
z
=
make_float2
(
a
*
b
.
z
.
x
,
a
*
b
.
z
.
y
);
c
.
w
=
make_float2
(
a
*
b
.
w
.
x
,
a
*
b
.
w
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint16_t
mul
(
uint16_t
a
,
uint16_t
b
)
{
uint16_t
c
;
asm
volatile
(
"mul.f16 %0, %1, %2;
\n
"
:
"=h"
(
c
)
:
"h"
(
a
),
"h"
(
b
));
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint32_t
mul
(
uint32_t
a
,
uint32_t
b
)
{
uint32_t
c
;
asm
volatile
(
"mul.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
c
)
:
"r"
(
a
),
"r"
(
b
));
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint32_t
mul
(
uint16_t
a
,
uint32_t
b
)
{
return
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
h0_h0
(
a
),
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint2
mul
(
uint2
a
,
uint2
b
)
{
uint2
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint2
mul
(
uint16_t
a
,
uint2
b
)
{
uint32_t
s
=
h0_h0
(
a
);
uint2
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint4
mul
(
uint4
a
,
uint4
b
)
{
uint4
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
c
.
z
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
z
,
b
.
z
);
c
.
w
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
a
.
w
,
b
.
w
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
uint4
mul
(
uint16_t
a
,
uint4
b
)
{
uint32_t
s
=
h0_h0
(
a
);
uint4
c
;
c
.
x
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
x
);
c
.
y
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
y
);
c
.
z
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
z
);
c
.
w
=
mul
<
uint32_t
,
uint32_t
,
uint32_t
>
(
s
,
b
.
w
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float
mul
(
uint16_t
a
,
uint16_t
b
)
{
float
fa
=
half_to_float
(
a
);
float
fb
=
half_to_float
(
b
);
return
fa
*
fb
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float
mul
(
uint16_t
a
,
float
b
)
{
return
half_to_float
(
a
)
*
b
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
uint32_t
a
,
uint32_t
b
)
{
float2
fa
=
half2_to_float2
(
a
);
float2
fb
=
half2_to_float2
(
b
);
return
mul
<
float2
,
float2
,
float2
>
(
fa
,
fb
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
float2
mul
(
uint16_t
a
,
uint32_t
b
)
{
return
mul
<
float2
,
uint32_t
,
uint32_t
>
(
h0_h0
(
a
),
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float4_
mul
(
uint2
a
,
uint2
b
)
{
Float4_
fc
;
fc
.
x
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float4_
mul
(
uint16_t
a
,
uint2
b
)
{
uint32_t
s
=
h0_h0
(
a
);
Float4_
fc
;
fc
.
x
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
y
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float8_
mul
(
uint4
a
,
uint4
b
)
{
Float8_
fc
;
fc
.
x
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
x
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
y
,
b
.
y
);
fc
.
z
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
z
,
b
.
z
);
fc
.
w
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
a
.
w
,
b
.
w
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
Float8_
mul
(
uint16_t
a
,
uint4
b
)
{
uint32_t
s
=
h0_h0
(
a
);
Float8_
fc
;
fc
.
x
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
x
);
fc
.
y
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
y
);
fc
.
z
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
z
);
fc
.
w
=
mul
<
float2
,
uint32_t
,
uint32_t
>
(
s
,
b
.
w
);
return
fc
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
#ifdef ENABLE_BF16
template
<
>
inline
__device__
__nv_bfloat16
mul
(
__nv_bfloat16
a
,
__nv_bfloat16
b
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
return
__hmul
(
a
,
b
);
#else
return
bf16hmul
(
a
,
b
);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
__nv_bfloat162
mul
(
__nv_bfloat162
a
,
__nv_bfloat162
b
)
{
return
bf16hmul2
(
a
,
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
__nv_bfloat162
mul
(
__nv_bfloat16
a
,
__nv_bfloat162
b
)
{
return
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
bf162bf162
(
a
),
b
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
bf16_4_t
mul
(
bf16_4_t
a
,
bf16_4_t
b
)
{
bf16_4_t
c
;
c
.
x
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
y
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
bf16_4_t
mul
(
__nv_bfloat16
a
,
bf16_4_t
b
)
{
__nv_bfloat162
s
=
bf162bf162
(
a
);
bf16_4_t
c
;
c
.
x
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
x
);
c
.
y
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
s
,
b
.
y
);
return
c
;
}
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
>
inline
__device__
bf16_8_t
mul
(
bf16_8_t
a
,
bf16_8_t
b
)
{
bf16_8_t
c
;
c
.
x
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
x
,
b
.
x
);
c
.
y
=
mul
<
__nv_bfloat162
,
__nv_bfloat162
,
__nv_bfloat162
>
(
a
.
y
,
b
.
y
);
c
.
z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
    c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
    return c;
}

template<>
inline __device__ bf16_8_t mul(__nv_bfloat16 a, bf16_8_t b) {
    __nv_bfloat162 s = bf162bf162(a);
    bf16_8_t c;
    c.x = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.x);
    c.y = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.y);
    c.z = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.z);
    c.w = mul<__nv_bfloat162, __nv_bfloat162, __nv_bfloat162>(s, b.w);
    return c;
}

template<>
inline __device__ float mul(__nv_bfloat16 a, __nv_bfloat16 b) {
    float fa = (float)a;
    float fb = (float)b;
    return fa * fb;
}

template<>
inline __device__ float mul(__nv_bfloat16 a, float b) {
    return __bfloat162float(a) * b;
}

template<>
inline __device__ float2 mul(__nv_bfloat162 a, __nv_bfloat162 b) {
    float2 fa = bf1622float2(a);
    float2 fb = bf1622float2(b);
    return mul<float2, float2, float2>(fa, fb);
}

template<>
inline __device__ float2 mul(__nv_bfloat16 a, __nv_bfloat162 b) {
    return mul<float2, __nv_bfloat162, __nv_bfloat162>(bf162bf162(a), b);
}

template<>
inline __device__ Float4_ mul(bf16_4_t a, bf16_4_t b) {
    Float4_ fc;
    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
    return fc;
}

template<>
inline __device__ Float4_ mul(__nv_bfloat16 a, bf16_4_t b) {
    __nv_bfloat162 s = bf162bf162(a);
    Float4_ fc;
    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
    return fc;
}

template<>
inline __device__ Float8_ mul(bf16_8_t a, bf16_8_t b) {
    Float8_ fc;
    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.x, b.x);
    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.y, b.y);
    fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.z, b.z);
    fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(a.w, b.w);
    return fc;
}

template<>
inline __device__ Float8_ mul(__nv_bfloat16 a, bf16_8_t b) {
    __nv_bfloat162 s = bf162bf162(a);
    Float8_ fc;
    fc.x = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.x);
    fc.y = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.y);
    fc.z = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.z);
    fc.w = mul<float2, __nv_bfloat162, __nv_bfloat162>(s, b.w);
    return fc;
}
#endif  // ENABLE_BF16

inline __device__ float sum(float v) { return v; }
inline __device__ float sum(float2 v) { return v.x + v.y; }
inline __device__ float sum(float4 v) { return v.x + v.y + v.z + v.w; }

#ifdef ENABLE_BF16
inline __device__ float sum(__nv_bfloat162 v) {
    float2 vf = bf1622float2(v);
    return vf.x + vf.y;
}
inline __device__ float sum(bf16_4_t v) { return sum(v.x) + sum(v.y); }
inline __device__ float sum(bf16_8_t v) { return sum(v.x) + sum(v.y) + sum(v.z) + sum(v.w); }
#endif  // ENABLE_BF16

inline __device__ float sum(uint16_t v) { return half_to_float(v); }

inline __device__ float sum(uint32_t v) {
    float2 tmp = half2_to_float2(v);
    return tmp.x + tmp.y;
}

inline __device__ float sum(uint2 v) {
    uint32_t c = add(v.x, v.y);
    return sum(c);
}

inline __device__ float sum(uint4 v) {
#if 1
    uint32_t c = add(v.x, v.y);
    c = add(c, v.z);
    c = add(c, v.w);
#else
    uint32_t c = add(v.x, v.y);
    uint32_t d = add(v.z, v.w);
    c = add(c, d);
#endif
    return sum(c);
}

inline __device__ float sum(Float4_ v) { return v.x.x + v.x.y + v.y.x + v.y.y; }

inline __device__ float sum(Float8_ v) {
    return v.x.x + v.x.y + v.y.x + v.y.y + v.z.x + v.z.y + v.w.x + v.w.y;
}

template<typename T>
inline __device__ float dot(T a, T b) {
    return sum(mul<T, T, T>(a, b));
}

template<typename A, typename T>
inline __device__ float dot(T a, T b) {
    return sum(mul<A, T, T>(a, b));
}

inline __device__ void zero(uint16_t& dst) { dst = uint16_t(0); }

template<typename T>
inline __device__ void zero(T& dst) {
    constexpr int WORDS = sizeof(T) / 4;
    union { T raw; uint32_t words[WORDS]; } tmp;
#pragma unroll
    for (int ii = 0; ii < WORDS; ++ii) {
        tmp.words[ii] = 0u;
    }
    dst = tmp.raw;
}

inline __device__ float2 rotary_embedding_coefficient(const int zid, const int rot_embed_dim, const float t_step, const float base) {
    const float inv_freq = t_step / pow(base, zid / (float)rot_embed_dim);
    return {cos(inv_freq), sin(inv_freq)};
}

inline __device__ float2 rotary_embedding_transform(const float2 v, const float2 coef) {
    float2 rot_v;
    rot_v.x = coef.x * v.x - coef.y * v.y;
    rot_v.y = coef.x * v.y + coef.y * v.x;
    return rot_v;
}

inline __device__ uint32_t rotary_embedding_transform(const uint32_t v, const float2 coef) {
    float2 fv = half2_to_float2(v);
    float2 rot_fv = rotary_embedding_transform(fv, coef);
    return float2_to_half2(rot_fv);
}

#ifdef ENABLE_BF16
inline __device__ __nv_bfloat162 rotary_embedding_transform(const __nv_bfloat162 v, const float2 coef) {
    float2 fv = bf1622float2(v);
    float2 rot_fv = rotary_embedding_transform(fv, coef);
    return __floats2bfloat162_rn(rot_fv.x, rot_fv.y);
}
#endif

inline __device__ void apply_rotary_embedding(float& q, int zid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    return;
}

inline __device__ void apply_rotary_embedding(float& q, float& k, int zid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    return;
}

inline __device__ void apply_rotary_embedding(float2& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (2 * tid >= rot_embed_dim) { return; }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
}

inline __device__ void apply_rotary_embedding(float2& q, float2& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (2 * tid >= rot_embed_dim) { return; }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
    k = rotary_embedding_transform(k, coef);
}

inline __device__ void apply_rotary_embedding(float4& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (4 * tid >= rot_embed_dim) { return; }
    Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q_.x = rotary_embedding_transform(q_.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q_.y = rotary_embedding_transform(q_.y, coef1);
}

inline __device__ void apply_rotary_embedding(float4& q, float4& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (4 * tid >= rot_embed_dim) { return; }
    Float4_& q_ = *reinterpret_cast<Float4_*>(&q);
    Float4_& k_ = *reinterpret_cast<Float4_*>(&k);
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q_.x = rotary_embedding_transform(q_.x, coef0);
    k_.x = rotary_embedding_transform(k_.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q_.y = rotary_embedding_transform(q_.y, coef1);
    k_.y = rotary_embedding_transform(k_.y, coef1);
}

inline __device__ void apply_rotary_embedding(uint32_t& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (2 * tid >= rot_embed_dim) { return; }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
}

inline __device__ void apply_rotary_embedding(uint32_t& q, uint32_t& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (2 * tid >= rot_embed_dim) { return; }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
    k = rotary_embedding_transform(k, coef);
}

inline __device__ void apply_rotary_embedding(uint2& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (4 * tid >= rot_embed_dim) { return; }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
}

inline __device__ void apply_rotary_embedding(uint2& q, uint2& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (4 * tid >= rot_embed_dim) { return; }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
}

inline __device__ void apply_rotary_embedding(uint4& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (8 * tid >= rot_embed_dim) { return; }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
    q.z = rotary_embedding_transform(q.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
    q.w = rotary_embedding_transform(q.w, coef3);
}

inline __device__ void apply_rotary_embedding(uint4& q, uint4& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (8 * tid >= rot_embed_dim) { return; }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
    q.z = rotary_embedding_transform(q.z, coef2);
    k.z = rotary_embedding_transform(k.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
    q.w = rotary_embedding_transform(q.w, coef3);
    k.w = rotary_embedding_transform(k.w, coef3);
}

#ifdef ENABLE_BF16
inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (2 * tid >= rot_embed_dim) { return; }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
}

inline __device__ void apply_rotary_embedding(__nv_bfloat162& q, __nv_bfloat162& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (2 * tid >= rot_embed_dim) { return; }
    const auto coef = rotary_embedding_coefficient(2 * tid, rot_embed_dim, t_step, base);
    q = rotary_embedding_transform(q, coef);
    k = rotary_embedding_transform(k, coef);
}

inline __device__ void apply_rotary_embedding(bf16_4_t& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (4 * tid >= rot_embed_dim) { return; }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
}

inline __device__ void apply_rotary_embedding(bf16_4_t& q, bf16_4_t& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (4 * tid >= rot_embed_dim) { return; }
    const auto coef0 = rotary_embedding_coefficient(4 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(4 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
}

inline __device__ void apply_rotary_embedding(bf16_8_t& q, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (8 * tid >= rot_embed_dim) { return; }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
    q.z = rotary_embedding_transform(q.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
    q.w = rotary_embedding_transform(q.w, coef3);
}

inline __device__ void apply_rotary_embedding(bf16_8_t& q, bf16_8_t& k, int tid, int rot_embed_dim, int t_step, const float base = 10000.0f) {
    if (8 * tid >= rot_embed_dim) { return; }
    const auto coef0 = rotary_embedding_coefficient(8 * tid, rot_embed_dim, t_step, base);
    q.x = rotary_embedding_transform(q.x, coef0);
    k.x = rotary_embedding_transform(k.x, coef0);
    const auto coef1 = rotary_embedding_coefficient(8 * tid + 2, rot_embed_dim, t_step, base);
    q.y = rotary_embedding_transform(q.y, coef1);
    k.y = rotary_embedding_transform(k.y, coef1);
    const auto coef2 = rotary_embedding_coefficient(8 * tid + 4, rot_embed_dim, t_step, base);
    q.z = rotary_embedding_transform(q.z, coef2);
    k.z = rotary_embedding_transform(k.z, coef2);
    const auto coef3 = rotary_embedding_coefficient(8 * tid + 6, rot_embed_dim, t_step, base);
    q.w = rotary_embedding_transform(q.w, coef3);
    k.w = rotary_embedding_transform(k.w, coef3);
}
#endif  // ENABLE_BF16

template<typename Vec_T, typename T>
__device__ __inline__ void vec_from_smem_transpose(Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);

template<>
__device__ __inline__ void vec_from_smem_transpose(float& vec, float* smem, int transpose_idx, int smem_pitch) {
    return;
}

template<>
__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) {
    union { uint32_t u32; uint16_t u16[2]; } tmp;
    tmp.u16[0] = smem[transpose_idx];
    tmp.u16[1] = smem[smem_pitch + transpose_idx];
    vec = tmp.u32;
}

template<>
__device__ __inline__ void vec_from_smem_transpose(uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) {
    union { uint32_t u32; uint16_t u16[2]; } tmp_1, tmp_2;
    tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
    tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);
    union { uint2 u32x2; uint16_t u16[4]; } tmp_3;
    tmp_3.u16[0] = tmp_1.u16[0];
    tmp_3.u16[1] = tmp_2.u16[0];
    tmp_3.u16[2] = tmp_1.u16[1];
    tmp_3.u16[3] = tmp_2.u16[1];
    vec = tmp_3.u32x2;
}

template<>
__device__ __inline__ void vec_from_smem_transpose(uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) {
    union { uint64_t u64; uint16_t u16[4]; } tmp_1, tmp_2;
    tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
    tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);
    union { uint4 u32x4; uint16_t u16[8]; } tmp_3;
    tmp_3.u16[0] = tmp_1.u16[0];
    tmp_3.u16[1] = tmp_2.u16[0];
    tmp_3.u16[2] = tmp_1.u16[1];
    tmp_3.u16[3] = tmp_2.u16[1];
    tmp_3.u16[4] = tmp_1.u16[2];
    tmp_3.u16[5] = tmp_2.u16[2];
    tmp_3.u16[6] = tmp_1.u16[3];
    tmp_3.u16[7] = tmp_2.u16[3];
    vec = tmp_3.u32x4;
}

#ifdef ENABLE_BF16
template<>
__device__ __inline__ void vec_from_smem_transpose(bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) {
    union { uint32_t u32; __nv_bfloat16 bf16[2]; } tmp_1, tmp_2;
    tmp_1.u32 = *reinterpret_cast<uint32_t*>(&smem[transpose_idx]);
    tmp_2.u32 = *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]);
    vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
    vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
}

template<>
__device__ __inline__ void vec_from_smem_transpose(bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) {
    union { uint64_t u64; __nv_bfloat16 bf16[4]; } tmp_1, tmp_2;
    tmp_1.u64 = *reinterpret_cast<uint64_t*>(&smem[transpose_idx]);
    tmp_2.u64 = *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]);
    vec.x = __nv_bfloat162{tmp_1.bf16[0], tmp_2.bf16[0]};
    vec.y = __nv_bfloat162{tmp_1.bf16[1], tmp_2.bf16[1]};
    vec.z = __nv_bfloat162{tmp_1.bf16[2], tmp_2.bf16[2]};
    vec.w = __nv_bfloat162{tmp_1.bf16[3], tmp_2.bf16[3]};
}
#endif  // ENABLE_BF16

template<>
__device__ __inline__ void vec_from_smem_transpose(float4& vec, float* smem, int transpose_idx, int smem_pitch) {
    vec.x = smem[transpose_idx];
    vec.z = smem[transpose_idx + 1];
    vec.y = smem[smem_pitch + transpose_idx];
    vec.w = smem[smem_pitch + transpose_idx + 1];
}

template<>
__device__ __inline__ void vec_from_smem_transpose(uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) {
    union { uint32_t u32; half u16[2]; } tmp;
    tmp.u16[0] = smem[transpose_idx];
    tmp.u16[1] = smem[smem_pitch + transpose_idx];
    vec = tmp.u32;
}

#ifdef ENABLE_BF16
template<>
__device__ __inline__ void vec_from_smem_transpose(__nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) {
    vec.x = smem[transpose_idx];
    vec.y = smem[smem_pitch + transpose_idx];
}
#endif

template<>
__device__ __inline__ void vec_from_smem_transpose(float2& vec, float* smem, int transpose_idx, int smem_pitch) {
    vec.x = smem[transpose_idx];
    vec.y = smem[smem_pitch + transpose_idx];
}

template<typename Vec_T, typename T>
__device__ __inline__ void write_smem_transpose(const Vec_T& vec, T* smem, int transpose_idx, int smem_pitch);

template<>
__device__ __inline__ void write_smem_transpose(const float& vec, float* smem, int transpose_idx, int smem_pitch) {
    return;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint4& vec, uint16_t* smem, int transpose_idx, int smem_pitch) {
    union { uint64_t u64; uint16_t u16[4]; } tmp_1, tmp_2;
    union { uint4 u32x4; uint16_t u16[8]; } tmp_3;
    tmp_3.u32x4 = vec;
    tmp_1.u16[0] = tmp_3.u16[0];
    tmp_2.u16[0] = tmp_3.u16[1];
    tmp_1.u16[1] = tmp_3.u16[2];
    tmp_2.u16[1] = tmp_3.u16[3];
    tmp_1.u16[2] = tmp_3.u16[4];
    tmp_2.u16[2] = tmp_3.u16[5];
    tmp_1.u16[3] = tmp_3.u16[6];
    tmp_2.u16[3] = tmp_3.u16[7];
    *reinterpret_cast<uint64_t*>(&smem[transpose_idx]) = tmp_1.u64;
    *reinterpret_cast<uint64_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u64;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint2& vec, uint16_t* smem, int transpose_idx, int smem_pitch) {
    union { uint32_t u32; uint16_t u16[2]; } tmp_1, tmp_2;
    union { uint2 u32x2; uint16_t u16[4]; } tmp_3;
    tmp_3.u32x2 = vec;
    tmp_1.u16[0] = tmp_3.u16[0];
    tmp_2.u16[0] = tmp_3.u16[1];
    tmp_1.u16[1] = tmp_3.u16[2];
    tmp_2.u16[1] = tmp_3.u16[3];
    *reinterpret_cast<uint32_t*>(&smem[transpose_idx]) = tmp_1.u32;
    *reinterpret_cast<uint32_t*>(&smem[smem_pitch + transpose_idx]) = tmp_2.u32;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint32_t& vec, uint16_t* smem, int transpose_idx, int smem_pitch) {
    union { uint32_t u32; uint16_t u16[2]; } tmp;
    tmp.u32 = vec;
    smem[transpose_idx] = tmp.u16[0];
    smem[smem_pitch + transpose_idx] = tmp.u16[1];
}

template<>
__device__ __inline__ void write_smem_transpose(const float4& vec, float* smem, int transpose_idx, int smem_pitch) {
    smem[transpose_idx] = vec.x;
    smem[transpose_idx + 1] = vec.z;
    smem[smem_pitch + transpose_idx] = vec.y;
    smem[smem_pitch + transpose_idx + 1] = vec.w;
}

template<>
__device__ __inline__ void write_smem_transpose(const uint32_t& vec, half* smem, int transpose_idx, int smem_pitch) {
    union { uint32_t u32; half u16[2]; } tmp;
    tmp.u32 = vec;
    smem[transpose_idx] = tmp.u16[0];
    smem[smem_pitch + transpose_idx] = tmp.u16[1];
}

#ifdef ENABLE_BF16
template<>
__device__ __inline__ void write_smem_transpose(const __nv_bfloat162& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) {
    smem[transpose_idx] = vec.x;
    smem[smem_pitch + transpose_idx] = vec.y;
}

template<>
__device__ __inline__ void write_smem_transpose(const bf16_4_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) {
    write_smem_transpose(reinterpret_cast<const uint2&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
}

template<>
__device__ __inline__ void write_smem_transpose(const bf16_8_t& vec, __nv_bfloat16* smem, int transpose_idx, int smem_pitch) {
    write_smem_transpose(reinterpret_cast<const uint4&>(vec), reinterpret_cast<uint16_t*>(smem), transpose_idx, smem_pitch);
}
#endif

template<>
__device__ __inline__ void write_smem_transpose(const float2& vec, float* smem, int transpose_idx, int smem_pitch) {
    smem[transpose_idx] = vec.x;
    smem[smem_pitch + transpose_idx] = vec.y;
}

}  // namespace mmha
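For reference, the rotary helpers in the deleted header above (`rotary_embedding_coefficient` and `rotary_embedding_transform`) apply the standard rotary position embedding: each two-channel slice of a head is rotated by an angle determined by the timestep and the channel index. Using the same `t_step`, `zid`, `base`, and `rot_embed_dim` names as the code, the computation is:

```
\theta = \frac{t_{\text{step}}}{\text{base}^{\,zid / \text{rot\_embed\_dim}}},
\qquad
\begin{pmatrix} x' \\ y' \end{pmatrix}
=
\begin{pmatrix} \cos\theta & -\sin\theta \\ \sin\theta & \cos\theta \end{pmatrix}
\begin{pmatrix} x \\ y \end{pmatrix}
```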
awq_cuda/attention/ft_attention.cpp deleted 100644 → 0
// Adapted from NVIDIA/FasterTransformer and FlashAttention

#include <torch/extension.h>
#include "ATen/cuda/CUDAContext.h"
#include <c10/cuda/CUDAGuard.h>

#include "ft_attention.h"
#include "decoder_masked_multihead_attention.h"

#define CHECK_DEVICE(x) TORCH_CHECK(x.device().type() == torch::kCUDA, #x " must be on CUDA")
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")

#define DISPATCH_FLOAT_AND_HALF_AND_BF16(TYPE, NAME, ...)                    \
    if (TYPE == at::ScalarType::Half) {                                      \
        using scalar_t = at::Half;                                           \
        __VA_ARGS__();                                                       \
    } else if (TYPE == at::ScalarType::BFloat16) {                           \
        using scalar_t = at::BFloat16;                                       \
        __VA_ARGS__();                                                       \
    } else if (TYPE == at::ScalarType::Float) {                              \
        using scalar_t = float;                                              \
        __VA_ARGS__();                                                       \
    } else {                                                                 \
        AT_ERROR(#NAME, " not implemented for type '", toString(TYPE), "'"); \
    }

template<typename T>
void masked_multihead_attention(const Masked_multihead_attention_params<T>& params, const cudaStream_t& stream);

template<typename T>
void cross_multihead_attention(const Masked_multihead_attention_params<T>& params, const cudaStream_t& stream);

template<typename T>
struct SATypeConverter {
    using Type = T;
};

template<>
struct SATypeConverter<at::Half> {
    using Type = uint16_t;
};

template<>
struct SATypeConverter<at::BFloat16> {
    using Type = __nv_bfloat16;
};

template<typename T>
void set_params(Masked_multihead_attention_params<T>& params,
                const size_t batch_size, const size_t nheads, const size_t nheads_kv,
                const size_t memory_max_seqlen, const size_t headdim, const int timestep,
                const int rotary_embedding_dim, const float rotary_base, const bool neox_rotary_style,
                const int qkv_batch_stride,
                T* q_ptr, T* k_ptr, T* v_ptr, T* k_cache_ptr, T* v_cache_ptr,
                int* length_per_sample, float* alibi_slopes_ptr, T* out_ptr) {
    // Reset the parameters
    memset(&params, 0, sizeof(params));
    params.q = q_ptr;
    params.k = k_ptr;
    params.v = v_ptr;
    params.q_bias = nullptr;
    params.k_bias = nullptr;
    params.v_bias = nullptr;
    params.k_cache = k_cache_ptr;
    params.v_cache = v_cache_ptr;
    params.linear_bias_slopes = alibi_slopes_ptr;
    params.out = out_ptr;
    params.cache_indir = nullptr;
    params.stride = qkv_batch_stride;
    params.batch_size = batch_size;
    params.beam_width = 1;
    params.memory_max_len = memory_max_seqlen;
    params.num_heads = nheads;
    params.num_kv_heads = nheads_kv;
    params.hidden_size_per_head = headdim;
    params.rotary_embedding_dim = rotary_embedding_dim;
    params.rotary_base = rotary_base;
    params.neox_rotary_style = neox_rotary_style;
    params.timestep = timestep;
    params.inv_sqrt_dh = 1.f / sqrt(float(headdim));
    params.total_padding_tokens = nullptr;
    params.masked_tokens = nullptr;
    params.prefix_prompt_lengths = nullptr;
    params.max_prefix_prompt_length = 0;
    params.relative_attention_bias = nullptr;
    params.relative_attention_bias_stride = 0;
    params.cross_attention_out = nullptr;
    params.max_decoder_seq_len = 0;
    params.is_return_cross_attentions = false;
    params.finished = nullptr;
    params.memory_length_per_sample = nullptr;
    params.length_per_sample = length_per_sample;
}

torch::Tensor single_query_attention(const torch::Tensor q,
                                     const torch::Tensor k,
                                     const torch::Tensor v,
                                     torch::Tensor k_cache,
                                     torch::Tensor v_cache,
                                     c10::optional<const torch::Tensor> length_per_sample_,
                                     c10::optional<const torch::Tensor> alibi_slopes_,
                                     const int timestep,
                                     const int rotary_embedding_dim,
                                     const float rotary_base,
                                     // neox_rotary_style = not interleaved
                                     const bool neox_rotary_style) {
    CHECK_DEVICE(q); CHECK_DEVICE(k); CHECK_DEVICE(v); CHECK_DEVICE(k_cache); CHECK_DEVICE(v_cache);
    int batch_size = v_cache.size(0);
    int nheads = q.size(1);
    int nheads_kv = v_cache.size(1);
    int memory_max_seqlen = v_cache.size(2);
    int headdim = v_cache.size(3);
    CHECK_SHAPE(q, batch_size, nheads, headdim);
    CHECK_SHAPE(k, batch_size, nheads_kv, headdim);
    CHECK_SHAPE(v, batch_size, nheads_kv, headdim);
    CHECK_SHAPE(v_cache, batch_size, nheads_kv, memory_max_seqlen, headdim);
    // k_cache shape: [B, H, Dh/x, L, x] where x=8 for fp16 and x=4 for fp32
    int packsize = k_cache.dtype() == torch::kFloat32 ? 4 : 8;
    CHECK_SHAPE(k_cache, batch_size, nheads_kv, headdim / packsize, memory_max_seqlen, packsize);
    TORCH_CHECK(q.stride(2) == 1 && q.stride(1) == headdim);
    TORCH_CHECK(k.stride(2) == 1 && k.stride(1) == headdim);
    TORCH_CHECK(v.stride(2) == 1 && v.stride(1) == headdim);
    // TORCH_CHECK(q.stride(0) == k.stride(0) && q.stride(0) == v.stride(0));
    CHECK_CONTIGUOUS(v_cache);
    CHECK_CONTIGUOUS(k_cache);

    if (length_per_sample_.has_value()) {
        auto length_per_sample = length_per_sample_.value();
        CHECK_DEVICE(length_per_sample);
        CHECK_SHAPE(length_per_sample, batch_size);
        CHECK_CONTIGUOUS(length_per_sample);
        TORCH_CHECK(length_per_sample.dtype() == torch::kInt32);
    }

    if (alibi_slopes_.has_value()) {
        auto alibi_slopes = alibi_slopes_.value();
        CHECK_DEVICE(alibi_slopes);
        CHECK_SHAPE(alibi_slopes, nheads);
        CHECK_CONTIGUOUS(alibi_slopes);
        TORCH_CHECK(alibi_slopes.dtype() == torch::kFloat32);
    }

    // Otherwise the kernel will be launched from cuda:0 device
    // Cast to char to avoid compiler warning about narrowing
    at::cuda::CUDAGuard device_guard{(char)q.get_device()};

    torch::Tensor out = torch::empty_like(q);

    DISPATCH_FLOAT_AND_HALF_AND_BF16(q.scalar_type(), "single_query_attention", [&] {
        using DataType = typename SATypeConverter<scalar_t>::Type;
        Masked_multihead_attention_params<DataType> params;
        set_params(params, batch_size, nheads, nheads_kv, memory_max_seqlen, headdim, timestep,
                   rotary_embedding_dim, rotary_base, neox_rotary_style, q.stride(0),
                   reinterpret_cast<DataType*>(q.data_ptr()),
                   reinterpret_cast<DataType*>(k.data_ptr()),
                   reinterpret_cast<DataType*>(v.data_ptr()),
                   reinterpret_cast<DataType*>(k_cache.data_ptr()),
                   reinterpret_cast<DataType*>(v_cache.data_ptr()),
                   length_per_sample_.has_value() ? length_per_sample_.value().data_ptr<int>() : nullptr,
                   alibi_slopes_.has_value() ? alibi_slopes_.value().data_ptr<float>() : nullptr,
                   reinterpret_cast<DataType*>(out.data_ptr()));
        auto stream = at::cuda::getCurrentCUDAStream();
        masked_multihead_attention(params, stream);
    });
    return out;
}
\ No newline at end of file
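The shape checks above define the calling convention for `single_query_attention`. Below is a minimal, illustrative Python sketch of that convention; `ft_ext` is a placeholder for whatever module name the extension built from `pybind_ft.cpp` ends up with, and passing `None` for the optional tensors is assumed to map to an empty `c10::optional`:

```python
# Hypothetical usage sketch; shapes follow the CHECK_SHAPE calls in ft_attention.cpp.
import torch
import ft_ext  # placeholder module name, depends on how the extension is built

B, H, H_kv, Dh, L = 1, 32, 32, 128, 2048   # batch, query heads, kv heads, head dim, max cache len
timestep = 16                              # tokens already stored in the cache

q = torch.randn(B, H, Dh, dtype=torch.float16, device="cuda")
k = torch.randn(B, H_kv, Dh, dtype=torch.float16, device="cuda")
v = torch.randn(B, H_kv, Dh, dtype=torch.float16, device="cuda")

# k_cache is packed as [B, H_kv, Dh/x, L, x] with x = 8 for fp16; v_cache is [B, H_kv, L, Dh].
k_cache = torch.zeros(B, H_kv, Dh // 8, L, 8, dtype=torch.float16, device="cuda")
v_cache = torch.zeros(B, H_kv, L, Dh, dtype=torch.float16, device="cuda")

out = ft_ext.single_query_attention(q, k, v, k_cache, v_cache, None, None, timestep)
print(out.shape)  # same shape as q: [B, H, Dh]
```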
awq_cuda/attention/ft_attention.h deleted 100644 → 0
#pragma once

#include <torch/extension.h>

torch::Tensor single_query_attention(const torch::Tensor q,
                                     const torch::Tensor k,
                                     const torch::Tensor v,
                                     torch::Tensor k_cache,
                                     torch::Tensor v_cache,
                                     c10::optional<const torch::Tensor> length_per_sample_,
                                     c10::optional<const torch::Tensor> alibi_slopes_,
                                     const int timestep,
                                     const int rotary_embedding_dim = 0,
                                     const float rotary_base = 10000.0f,
                                     const bool neox_rotary_style = true);
\ No newline at end of file
awq_cuda/layernorm/layernorm.cu deleted 100644 → 0
/*
Adapted from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/layernorm_kernels.cu
*/

#include <torch/extension.h>
#include <cuda_fp16.h>
#include "reduction.cuh"
#include "layernorm.h"
#include <cuda_runtime.h>
#include <c10/cuda/CUDAGuard.h>

static inline __device__ float to_float(half src) {
    return __half2float(src);
}

static inline __device__ float to_float(float src) {
    return src;
}

template<typename T>
__global__ void generalT5LayerNorm(
    const T* __restrict input, const T* __restrict gamma, T* output, const float layernorm_eps, int m, int n) {
    // layernorm module in the T5 style: no bias and no subtraction of mean.
    const int tid = threadIdx.x;

    __shared__ float s_variance;
    float variance = 0.0f;

    float local_var_sum = 0.0f;
    for (int i = tid; i < n; i += blockDim.x) {
        float diff = to_float(__ldg(&input[blockIdx.x * n + i]));
        local_var_sum += diff * diff;
    }
    variance = blockReduceSum(local_var_sum);

    if (threadIdx.x == 0) {
        s_variance = rsqrtf(variance / (float)n + layernorm_eps);
    }
    __syncthreads();

    for (int i = tid; i < n; i += blockDim.x) {
        output[blockIdx.x * n + i] =
            clamp_inf_for_half<T>((to_float(input[blockIdx.x * n + i]) * s_variance) * to_float(__ldg(&gamma[i])));
    }
}

template<typename T>
void invokeGeneralT5LayerNorm(T* out,
                              const T* input,
                              const T* gamma,
                              // const T* beta,
                              const float layernorm_eps,
                              const int m,
                              const int n) {
    dim3 grid(m);
    dim3 block(min(n, 1024));

    /* For general cases, n is equal to hidden_units, e.g., 512/1024.
       Since we have warp shuffle inside the code, block.x % 32 should be 0.
    */
    if (n % 32 != 0) {
        block.x = 1024;
    }

    block.x = block.x / (4 / sizeof(T));  // if using half, only need half of block.x

    /* should pay attention to the rsqrt precision */
    generalT5LayerNorm<T><<<grid, block>>>(input, gamma, out, layernorm_eps, m, n);  // For gpt-3
}

template void invokeGeneralT5LayerNorm(half* out,
                                       const half* input,
                                       const half* gamma,
                                       // const half* beta,
                                       const float layernorm_eps,
                                       const int m,
                                       const int n);

template void invokeGeneralT5LayerNorm(float* out,
                                       const float* input,
                                       const float* gamma,
                                       // const half* beta,
                                       const float layernorm_eps,
                                       const int m,
                                       const int n);

// input b, n, c
void layernorm_forward_cuda(torch::Tensor _input, torch::Tensor _gamma, torch::Tensor _out, float eps) {
    int m = _input.size(0) * _input.size(1);
    int n = _input.size(2);
    const at::cuda::OptionalCUDAGuard device_guard(device_of(_input));

    auto input = reinterpret_cast<half*>(_input.data_ptr<at::Half>());
    auto gamma = reinterpret_cast<half*>(_gamma.data_ptr<at::Half>());
    auto out = reinterpret_cast<half*>(_out.data_ptr<at::Half>());

    invokeGeneralT5LayerNorm(out, input, gamma, eps, m, n);
}
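The kernel above is an RMS-style (T5) layer norm: each row is scaled by the reciprocal root-mean-square of its elements and by a learned `gamma`, with no mean subtraction and no bias. A rough PyTorch reference of the same computation, useful for sanity-checking the kernel output (a sketch, not part of the deleted sources):

```python
import torch

def t5_layernorm_reference(x: torch.Tensor, gamma: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # x: [b, n, c]; normalize over the last dimension, no mean subtraction, no bias.
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)
    return (x.float() * torch.rsqrt(variance + eps) * gamma.float()).to(x.dtype)
```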
awq_cuda/layernorm/layernorm.h deleted 100644 → 0
#include <torch/extension.h>

void layernorm_forward_cuda(torch::Tensor _input, torch::Tensor _gamma, torch::Tensor _out, float eps);
awq_cuda/layernorm/reduction.cuh deleted 100644 → 0
/*
Adapted from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/reduce_kernel_utils.cuh
*/
#pragma once

#include <assert.h>
#if ((__CUDACC_VER_MAJOR__ > 11) || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))
#include <cooperative_groups/reduce.h>
#else
#include <cooperative_groups.h>
#endif
#include <cuda_fp16.h>
#include <cuda_runtime.h>
#include <float.h>
#include <type_traits>

#define HALF_FLT_MAX 65504.F
#define FINAL_MASK 0xffffffff

template<typename T>
inline __device__ T add(T a, T b) {
    return a + b;
}

template<>
inline __device__ half2 add(half2 a, half2 b) {
    return __hadd2(a, b);
}

template<>
inline __device__ half add(half a, half b) {
    return __hadd(a, b);
}

template<typename T>
__inline__ __device__ T warpReduceSum(T val) {
#pragma unroll
    for (int mask = 16; mask > 0; mask >>= 1)
        val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32));
    //__shfl_sync bf16 return float when sm < 80
    return val;
}

/* Calculate the sum of all elements in a block */
template<typename T>
__inline__ __device__ T blockReduceSum(T val) {
    static __shared__ T shared[32];
    int lane = threadIdx.x & 0x1f;
    int wid = threadIdx.x >> 5;

    val = warpReduceSum<T>(val);

    if (lane == 0)
        shared[wid] = val;

    __syncthreads();

    // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
    // blockDim.x is not divided by 32
    val = (threadIdx.x < (blockDim.x / 32.f)) ? shared[lane] : (T)(0.0f);
    val = warpReduceSum<T>(val);

    return val;
}

template<typename T>
__device__ __forceinline__ T clamp_inf_for_half(const float input) {
    return input;
}

template<>
__device__ __forceinline__ half clamp_inf_for_half(const float input) {
    // clamp inf values to enable fp16 training
    return input > 0.0f ? __float2half(min(input, HALF_FLT_MAX - 1000)) : __float2half(max(input, -HALF_FLT_MAX + 1000));
}
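For intuition, `warpReduceSum` above is a butterfly (XOR) reduction over the 32 lanes of a warp: each lane repeatedly adds the value held by the lane whose index differs in one bit, so after five steps every lane holds the full sum. A small Python sketch of the same access pattern (illustrative only, not part of the deleted file):

```python
def warp_reduce_sum(lane_vals):
    # lane_vals: 32 per-lane values; mirrors the __shfl_xor_sync loop above.
    vals = list(lane_vals)
    mask = 16
    while mask > 0:
        vals = [vals[i] + vals[i ^ mask] for i in range(32)]  # lane i reads lane i^mask
        mask >>= 1
    return vals  # every lane now holds the sum of all 32 inputs

assert warp_reduce_sum(range(32))[0] == sum(range(32))
```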
awq_cuda/position_embedding/pos_encoding.h deleted 100644 → 0
#pragma once
#include <torch/extension.h>

void rotary_embedding_neox(torch::Tensor& positions,
                           torch::Tensor& query,
                           torch::Tensor& key,
                           int head_size,
                           torch::Tensor& cos_sin_cache);
\ No newline at end of file
awq_cuda/position_embedding/pos_encoding_kernels.cu deleted 100644 → 0
/*
Adapted from the VLLM project:
https://github.com/vllm-project/vllm/blob/main/csrc/pos_encoding_kernels.cu
*/

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>

#include "pos_encoding.h"

template<typename scalar_t>
__global__ void rotary_embedding_neox_kernel(
    const int64_t* __restrict__ positions,       // [num_tokens]
    scalar_t* __restrict__ query,                // [num_tokens, num_heads, head_size]
    scalar_t* __restrict__ key,                  // [num_tokens, num_heads, head_size]
    const scalar_t* __restrict__ cos_sin_cache,  // [max_position, 2, rot_dim // 2]
    const int rot_dim,
    const int stride,
    const int num_heads,
    const int head_size) {
    // Each thread block is responsible for one token.
    const int token_idx = blockIdx.x;
    int64_t pos = positions[token_idx];
    const scalar_t* cache_ptr = cos_sin_cache + pos * rot_dim;

    const int embed_dim = rot_dim / 2;
    const int n = num_heads * embed_dim;
    for (int i = threadIdx.x; i < n; i += blockDim.x) {
        const int head_idx = i / embed_dim;
        const int token_head = token_idx * stride + head_idx * head_size;

        const int rot_offset = i % embed_dim;
        const int x_index = rot_offset;
        const int y_index = embed_dim + rot_offset;

        const int out_x = token_idx * stride + head_idx * head_size + x_index;
        const int out_y = token_idx * stride + head_idx * head_size + y_index;

        const scalar_t cos = __ldg(cache_ptr + x_index);
        const scalar_t sin = __ldg(cache_ptr + y_index);

        const scalar_t q_x = query[token_head + x_index];
        const scalar_t q_y = query[token_head + y_index];
        query[out_x] = q_x * cos - q_y * sin;
        query[out_y] = q_y * cos + q_x * sin;

        const scalar_t k_x = key[token_head + x_index];
        const scalar_t k_y = key[token_head + y_index];
        key[out_x] = k_x * cos - k_y * sin;
        key[out_y] = k_y * cos + k_x * sin;
    }
}

void rotary_embedding_neox(torch::Tensor& positions,      // [b, num_tokens]
                           torch::Tensor& query,          // [b, num_tokens, 1, num_heads, head_size]
                           torch::Tensor& key,            // [b, num_tokens, 1, num_heads, head_size]
                           int head_size,
                           torch::Tensor& cos_sin_cache)  // [max_position, rot_dim]
{
    int num_tokens = query.size(0) * query.size(1);
    int rot_dim = cos_sin_cache.size(1);
    int num_heads = query.size(-2);
    int stride = num_heads * head_size;
    // TORCH_CHECK(stride == key.stride(0));

    dim3 grid(num_tokens);
    dim3 block(std::min(num_heads * rot_dim / 2, 512));
    const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
    AT_DISPATCH_FLOATING_TYPES_AND2(
        at::ScalarType::Half,
        at::ScalarType::BFloat16,
        query.scalar_type(),
        "rotary_embedding_neox",
        [&] {
            rotary_embedding_neox_kernel<scalar_t><<<grid, block, 0, stream>>>(
                positions.data_ptr<int64_t>(),
                query.data_ptr<scalar_t>(),
                key.data_ptr<scalar_t>(),
                cos_sin_cache.data_ptr<scalar_t>(),
                rot_dim,
                stride,
                num_heads,
                head_size);
        });
}
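This kernel rotates the first `rot_dim` channels of each head in GPT-NeoX (non-interleaved) style: channel `i` is paired with channel `i + rot_dim/2`, and the cosines and sines come from a precomputed `cos_sin_cache` indexed by position. A rough PyTorch equivalent for a single head vector, for clarity only (a sketch, not part of the deleted file):

```python
import torch

def rotary_neox_reference(x: torch.Tensor, cos_sin_cache: torch.Tensor, pos: int) -> torch.Tensor:
    # x: [head_size]; cos_sin_cache: [max_position, rot_dim] laid out as [cos | sin] per position.
    rot_dim = cos_sin_cache.size(1)
    half = rot_dim // 2
    cos, sin = cos_sin_cache[pos, :half], cos_sin_cache[pos, half:]
    x1, x2 = x[:half].clone(), x[half:rot_dim].clone()
    out = x.clone()
    out[:half] = x1 * cos - x2 * sin
    out[half:rot_dim] = x2 * cos + x1 * sin
    return out  # channels beyond rot_dim are left untouched
```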
awq_cuda/pybind_awq.cpp deleted 100644 → 0
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "layernorm/layernorm.h"
#include "quantization/gemm_cuda.h"
#include "quantization/gemv_cuda.h"
#include "position_embedding/pos_encoding.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("layernorm_forward_cuda", &layernorm_forward_cuda, "FasterTransformer layernorm kernel");
    m.def("gemm_forward_cuda", &gemm_forward_cuda, "Quantized GEMM kernel.");
    m.def("gemmv2_forward_cuda", &gemmv2_forward_cuda, "Quantized v2 GEMM kernel.");
    m.def("gemv_forward_cuda", &gemv_forward_cuda, "Quantized GEMV kernel.");
    m.def("rotary_embedding_neox", &rotary_embedding_neox, "Apply GPT-NeoX style rotary embedding to query and key");
}
\ No newline at end of file
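For orientation, the functions bound above are called from Python with plain CUDA tensors. A minimal illustrative sketch of the layernorm binding, where `awq_ext` is only a placeholder for the module name chosen at build time via `TORCH_EXTENSION_NAME`:

```python
# Hypothetical usage sketch for the layernorm binding; dtype must be fp16.
import torch
import awq_ext  # placeholder module name

x = torch.randn(1, 16, 4096, dtype=torch.float16, device="cuda")  # [b, n, c]
gamma = torch.ones(4096, dtype=torch.float16, device="cuda")
out = torch.empty_like(x)
awq_ext.layernorm_forward_cuda(x, gamma, out, 1e-6)  # result is written into `out`
```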
awq_cuda/pybind_ft.cpp deleted 100644 → 0
#include <pybind11/pybind11.h>
#include <torch/extension.h>
#include "attention/ft_attention.h"

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m)
{
    m.def("single_query_attention", &single_query_attention, "Attention with a single query",
          py::arg("q"), py::arg("k"), py::arg("v"), py::arg("k_cache"), py::arg("v_cache"),
          py::arg("length_per_sample_"), py::arg("alibi_slopes_"), py::arg("timestep"),
          py::arg("rotary_embedding_dim") = 0, py::arg("rotary_base") = 10000.0f,
          py::arg("neox_rotary_style") = true);
}
\ No newline at end of file