OpenDAS / Lmdeploy · Commits

Commit 6e58fced (Unverified)
Authored Jul 03, 2023 by tpoisonooo; committed by GitHub, Jul 03, 2023
fix(kernel): speed degrade (#41)
* feat(template): remote diff
* feat(cmake): use c++17
Parent: 8aa6eb10

Showing 3 changed files with 68 additions and 66 deletions:

CMakeLists.txt (+1, -1)
src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu (+25, -13)
src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh (+42, -52)
CMakeLists.txt
@@ -84,7 +84,7 @@ if(USE_TRITONSERVER_DATATYPE)
     add_definitions("-DUSE_TRITONSERVER_DATATYPE")
 endif()
-set(CXX_STD "14" CACHE STRING "C++ standard")
+set(CXX_STD "17" CACHE STRING "C++ standard")
 set(CUDA_PATH ${CUDA_TOOLKIT_ROOT_DIR})
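The only code change in this file is the cached CXX_STD value moving from 14 to 17. The commit does not state which C++17 features the kernels rely on; one feature commonly wanted in CUDA/C++ code of this kind is if constexpr, which discards the untaken branch of a compile-time condition. A generic illustration of that C++17 feature (not code from this repository):

#include <cstdio>

// With C++17, if constexpr keeps only the taken branch in each template
// instantiation, so policy-specific code can be selected at compile time.
template<int QUANT_POLICY>
void store_value(float v)
{
    if constexpr (QUANT_POLICY == 0) {
        std::printf("store fp value %f\n", v);                      // compiled only for QUANT_POLICY == 0
    }
    else {
        std::printf("store int8 value %d\n", static_cast<int>(v));  // compiled only for QUANT_POLICY != 0
    }
}

int main()
{
    store_value<0>(1.5f);
    store_value<4>(1.5f);
    return 0;
}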
src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
@@ -26,10 +26,10 @@
 ////////////////////////////////////////////////////////////////////////////////////////////////////
-#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, stream)             \
+#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, QUANT_POLICY, stream) \
     size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK);                              \
     dim3   grid(params.num_heads, params.batch_size);                                                                  \
-    mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS>    \
+    mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, QUANT_POLICY> \
         <<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
 ////////////////////////////////////////////////////////////////////////////////////////////////////
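The macro now takes a QUANT_POLICY argument and forwards it straight into the kernel's template argument list, so each launch binds the quantization policy at compile time. A minimal sketch of that plumbing, with a hypothetical toy_kernel standing in for the real mmha kernel:

#include <cuda_runtime.h>

// Toy stand-in: the policy travels as a template argument, not as a field
// that every thread has to read and test at run time.
template<typename T, int QUANT_POLICY>
__global__ void toy_kernel(T* out)
{
    out[threadIdx.x] = static_cast<T>(QUANT_POLICY);  // QUANT_POLICY is a compile-time constant here
}

#define TOY_LAUNCH_KERNEL(T, QUANT_POLICY, out, stream)                                                                \
    toy_kernel<T, QUANT_POLICY><<<1, 32, 0, stream>>>(out)

void launch_toy(float* out, cudaStream_t stream)
{
    TOY_LAUNCH_KERNEL(float, 4, out, stream);  // instantiates and launches toy_kernel<float, 4>
}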
@@ -40,18 +40,30 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st
 {
     constexpr int THREADS_PER_VALUE = threads_per_value_t<T, Dh_MAX>::value;
     // constexpr bool DO_CROSS_ATTENTION = std::is_same<KERNEL_PARAMS_TYPE, Cross_multihead_attention_params<T>>::value;
-    int tlength = params.timestep;
+    const int tlength = params.timestep;
     FT_CHECK(params.cache_indir == nullptr);
-    if (tlength < 32) {
-        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, stream);
-    }
-    else if (tlength < 2048) {
-        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, stream);
-    }
-    else {
-        MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, stream);
+    if (params.int8_mode == 4) {
+        if (tlength < 32) {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, 4, stream);
+        }
+        else if (tlength < 2048) {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, 4, stream);
+        }
+        else {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, 4, stream);
+        }
+    }
+    else {
+        if (tlength < 32) {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, 0, stream);
+        }
+        else if (tlength < 2048) {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 2, THREADS_PER_VALUE, 128, false, 0, stream);
+        }
+        else {
+            MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, 0, stream);
+        }
     }
 }
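The launcher now inspects params.int8_mode once on the host and expands the macro with either 4 or 0 as QUANT_POLICY, so the quantization decision is no longer re-evaluated by every thread inside the kernel; that is the mechanism behind the speed fix. A self-contained sketch of the same runtime-to-compile-time dispatch pattern, using toy names rather than the real launcher:

#include <cuda_runtime.h>

// The policy is a template parameter, so the compiler can drop the dead
// branch entirely from each instantiation.
template<int QUANT_POLICY>
__global__ void cache_kernel(const float* in, float* out, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) {
        return;
    }
    if (QUANT_POLICY == 0) {
        out[i] = in[i];  // plain floating-point path
    }
    else {  // QUANT_POLICY == 4: pretend int8 round-trip through the cache
        float clamped = fminf(fmaxf(in[i], -1.f), 1.f);
        out[i]        = static_cast<float>(static_cast<int>(clamped * 127.f)) / 127.f;
    }
}

// Host-side dispatch: the runtime flag is examined once, here, not per thread.
void launch_cache_kernel(const float* in, float* out, int n, int int8_mode, cudaStream_t stream)
{
    int block = 256;
    int grid  = (n + block - 1) / block;
    if (int8_mode == 4) {
        cache_kernel<4><<<grid, block, 0, stream>>>(in, out, n);
    }
    else {
        cache_kernel<0><<<grid, block, 0, stream>>>(in, out, n);
    }
}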
src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
@@ -17,9 +17,8 @@
 #include "src/turbomind/kernels/decoder_masked_multihead_attention.h"
 #include "src/turbomind/kernels/decoder_masked_multihead_attention_utils.h"
-#include "src/turbomind/models/llama/llama_utils.h"
-// #include "src/turbomind/utils/cuda_bf16_wrapper.h"
-// #include "src/turbomind/utils/cuda_fp8_utils.h"
+#include "src/turbomind/utils/cuda_bf16_wrapper.h"
+#include "src/turbomind/utils/cuda_fp8_utils.h"
 #include "src/turbomind/utils/cuda_type_utils.cuh"
 #include <assert.h>
 #include <float.h>
@@ -1272,7 +1271,8 @@ template<typename T, // The type of the inputs. Supported types: float and half
          int THREADS_PER_KEY,    // The number of threads per key.
          int THREADS_PER_VALUE,  // The number of threads per value.
          int THREADS_PER_BLOCK,  // The number of threads in a threadblock.
-         bool HAS_BEAMS>
+         bool HAS_BEAMS,
+         int  QUANT_POLICY>  // quantization method
 __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> params)
 {
@@ -1462,16 +1462,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
             int offset = bhi * params.memory_max_len * Dh + co * params.memory_max_len * QK_ELTS_IN_16B
                          + tlength_circ * QK_ELTS_IN_16B + ci;
-            if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+            if (not QUANT_POLICY) {
+                *reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
+            }
+            else if (QUANT_POLICY == 4) {
                 using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
                 Packed_Int8_t k_int8 = quant(k, k_scale);
                 int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache);
                 *reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
             }
-            else {
-                *reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
-            }
         }
         else {
             int offset;
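In the int8 branch, quant(k, k_scale) packs the key vector into int8 with a scale before it is written into the K cache; the old code reached this branch through a runtime check of params.int8_mode, the new code through the compile-time QUANT_POLICY. The real quant() in turbomind operates on packed vector types; a simplified scalar sketch of the same symmetric quantization, with a hypothetical helper name:

#include <cstdint>

// Simplified symmetric int8 quantization: scale down, clamp to the int8
// range, round to nearest. turbomind's real quant() handles packed vectors.
__device__ inline int8_t quant_scalar(float value, float scale)
{
    float q = value / scale;
    q       = fminf(fmaxf(q, -128.f), 127.f);
    return static_cast<int8_t>(__float2int_rn(q));
}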
@@ -1484,17 +1483,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
                          + co * QK_ELTS_IN_16B + ci;
             }
-            if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+            if (not QUANT_POLICY) {
+                *reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
+            }
+            else if (QUANT_POLICY == 4) {
                 using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
                 Packed_Int8_t k_int8 = quant(k, k_scale);
                 int8_t* dst_ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
                 *reinterpret_cast<Packed_Int8_t*>(&dst_ptr[offset]) = k_int8;
             }
-            else {
-                *reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
-            }
         }
     }
 }
@@ -1575,7 +1573,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
     T*      k_cache_batch      = nullptr;
     int8_t* k_cache_batch_int8 = nullptr;
-    if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+    if (not QUANT_POLICY) {
+        k_cache_batch = params.k_cache_per_sample ?
+                            (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset
+                             + hi * params.memory_max_len * Dh + ki) :
+                            &params.k_cache[bhi * params.memory_max_len * Dh + ki];
+        // Base pointer for the beam's batch, before offsetting with indirection buffer
+        // T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
+    }
+    else if (QUANT_POLICY == 4) {
         // convert k_cache_per_sample to int8
         if (params.k_cache_per_sample) {
             int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
@@ -1586,14 +1590,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
             k_cache_batch_int8 = &ptr[bhi * params.memory_max_len * Dh + ki];
         }
     }
-    else {
-        T* k_cache = params.k_cache_per_sample ?
-                         (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset
-                          + hi * params.memory_max_len * Dh + ki) :
-                         &params.k_cache[bhi * params.memory_max_len * Dh + ki];
-        // Base pointer for the beam's batch, before offsetting with indirection buffer
-        // T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
-        k_cache_batch = k_cache;
-    }

     // Pick a number of keys to make sure all the threads of a warp enter (due to shfl_sync).
     // int ti_end = div_up(params.timestep, K_PER_WARP) * K_PER_WARP;
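Both the removed and the surviving branch resolve to the same kind of address arithmetic: a per-sample block addressed as k_cache_per_sample[bi] + kv_cache_per_sample_offset + hi * memory_max_len * Dh + ki, or the contiguous k_cache addressed as bhi * memory_max_len * Dh + ki. A small host-side sketch of that index computation, reusing the parameter names purely for illustration:

#include <cstddef>

// Offset of element ki of head hi inside one sample's K cache, following the
// [head][time][Dh] layout implied by the kernel's address computation.
inline size_t k_cache_per_sample_index(size_t kv_cache_per_sample_offset,
                                       size_t hi,              // head index
                                       size_t memory_max_len,  // max sequence length
                                       size_t Dh,              // head dimension
                                       size_t ki)              // element within the head
{
    return kv_cache_per_sample_offset + hi * memory_max_len * Dh + ki;
}

// Offset into the single contiguous cache, where bhi already folds batch and head.
inline size_t k_cache_index(size_t bhi, size_t memory_max_len, size_t Dh, size_t ki)
{
    return bhi * memory_max_len * Dh + ki;
}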
@@ -1629,7 +1625,10 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
                 beam_offset = beam_indices[ti_circ] * params.num_heads * params.memory_max_len * Dh;
             }
-            if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+            if (not QUANT_POLICY) {
+                k[ii] = vec_conversion<K_vec_k, K_vec_m>(
+                    (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
+            }
+            else if (QUANT_POLICY == 4) {
                 using Packed_Int8_t  = typename packed_type<int8_t, num_elems<K_vec_m>::value>::type;
                 using Packed_Float_t = typename packed_type<float, num_elems<K_vec_m>::value>::type;
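On the load side, the int8 cache entries are widened back with dequant(..., k_scale) before the QK dot product. Mirroring the quantization sketch above, a simplified scalar version of the inverse step (the real dequant works on packed vectors):

#include <cstdint>

// Simplified symmetric int8 dequantization: widen and rescale.
__device__ inline float dequant_scalar(int8_t value, float scale)
{
    return static_cast<float>(value) * scale;
}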
@@ -1639,10 +1638,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
                     k[ii] = vec_conversion<K_vec_k, Packed_Float_t>(k_vec_m_float);
                 }
-                else {
-                    k[ii] = vec_conversion<K_vec_k, K_vec_m>(
-                        (*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
-                }
             }
         }
     }
@@ -1763,7 +1758,15 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
     int8_t* v_cache_int8       = nullptr;
     int8_t* v_cache_batch_int8 = nullptr;
-    if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+    if (not QUANT_POLICY) {
+        v_cache = params.v_cache_per_sample ?
+                      (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset
+                       + hi * params.memory_max_len * Dh + vi) :
+                      &params.v_cache[bhi * params.memory_max_len * Dh + vi];
+        // Base pointer for the beam's batch, before offsetting with indirection buffer
+        // T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
+        v_cache_batch = v_cache;
+    }
+    else if (QUANT_POLICY == 4) {
         if (params.v_cache_per_sample) {
             int8_t* ptr  = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
             v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi;
@@ -1775,15 +1778,6 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
             v_cache_batch_int8 = v_cache_int8;
         }
     }
-    else {
-        v_cache = params.v_cache_per_sample ?
-                      (params.v_cache_per_sample[bi] + params.kv_cache_per_sample_offset
-                       + hi * params.memory_max_len * Dh + vi) :
-                      &params.v_cache[bhi * params.memory_max_len * Dh + vi];
-        // Base pointer for the beam's batch, before offsetting with indirection buffer
-        // T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
-        v_cache_batch = v_cache;
-    }

     // The number of values processed per iteration of the loop.
     constexpr int V_PER_ITER = THREADS_PER_BLOCK / THREADS_PER_VALUE;
@@ -1834,17 +1828,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
             // Load the values from the cache.
             V_vec_k v;
-            if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+            if (not QUANT_POLICY) {
+                v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh]));
+            }
+            else if (QUANT_POLICY == 4) {
                 Packed_Int8_t v_vec_m_int8 =
                     *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti * Dh]);
                 Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
                 v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
             }
-            else {
-                v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh]));
-            }

             // Load the logits from shared memory.
 #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
@@ -1881,18 +1874,16 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
         const int beam_offset = HAS_BEAMS ? beam_src * params.num_heads * params.memory_max_len * Dh : 0;
         // Load the values from the cache.
         V_vec_k v;
-        if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+        if (not QUANT_POLICY) {
+            v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh]));
+        }
+        else if (QUANT_POLICY == 4) {
             Packed_Int8_t v_vec_m_int8 =
                 *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti_circ * Dh]);
             Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
             v = vec_conversion<V_vec_k, Packed_Float_t>(v_vec_m_float);
         }
-        else {
-            v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh]));
-        }

         // Load the logits from shared memory.
 #if defined(MMHA_USE_FP32_ACUM_FOR_LOGITS)
@@ -1938,14 +1929,13 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
         // Store the values with bias back to global memory in the cache for V.
         //*reinterpret_cast<V_vec_k*>(&v_cache[params.timestep*Dh]) = v;
-        if (params.int8_mode & QuantPolicy::kCacheKVInt8) {
+        if (not QUANT_POLICY) {
+            *reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
+        }
+        else if (QUANT_POLICY == 4) {
             using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
             Packed_Int8_t v_int8 = quant(v, v_scale);
             *reinterpret_cast<Packed_Int8_t*>(&v_cache_int8[tlength_circ * Dh]) = v_int8;
         }
-        else {
-            *reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
-        }
     }

     // Initialize the output value with the current timestep.