OpenDAS / Lmdeploy · Commit 208b6841
"git@developer.sourcefind.cn:OpenDAS/mmcv.git" did not exist on "e322848e3a466f2b1f5b2e1a9cb5552342db1e93"
Unverified commit 208b6841, authored Jul 06, 2023 by AllentDan, committed by GitHub on Jul 06, 2023

fix clang-format (#68)

parent b2393467
Showing 4 changed files with 65 additions and 46 deletions (+65 −46)
src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu  (+14 −6)
src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh  (+19 −11)
src/turbomind/python/bind.cpp  (+26 −26)
src/turbomind/python/dlpack.h  (+6 −3)
src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_128.cu
@@ -26,11 +26,18 @@

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-#define MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, QUANT_POLICY, stream) \
-    size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK);  \
-    dim3   grid(params.num_heads, params.batch_size);  \
-    mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, QUANT_POLICY>  \
-        <<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)
+#define MMHA_LAUNCH_KERNEL(  \
+    T, Dh, Dh_MAX, THDS_PER_KEY, THDS_PER_VALUE, THDS_PER_BLOCK, HAS_BEAMS, QUANT_POLICY, stream)  \
+    size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THDS_PER_VALUE, THDS_PER_BLOCK);  \
+    dim3   grid(params.num_heads, params.batch_size);  \
+    mmha::masked_multihead_attention_kernel<T,  \
+                                            Dh,  \
+                                            Dh_MAX,  \
+                                            THDS_PER_KEY,  \
+                                            THDS_PER_VALUE,  \
+                                            THDS_PER_BLOCK,  \
+                                            HAS_BEAMS,  \
+                                            QUANT_POLICY><<<grid, THDS_PER_BLOCK, smem_sz, stream>>>(params)

 ////////////////////////////////////////////////////////////////////////////////////////////////////
@@ -54,7 +61,8 @@ void mmha_launch_kernel(const KERNEL_PARAMS_TYPE& params, const cudaStream_t& st
         else {
             MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, 4, stream);
         }
-    } else {
+    }
+    else {
         if (tlength < 32) {
             MMHA_LAUNCH_KERNEL(T, Dh, Dh_MAX, 4, THREADS_PER_VALUE, 64, false, 0, stream);
         }
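Reading note: the two hunks above change only the layout of MMHA_LAUNCH_KERNEL, not its behavior. As a sketch of what the first call site above still expands to after preprocessing (not standalone code; params, THREADS_PER_VALUE and stream are already in scope inside mmha_launch_kernel):

    size_t smem_sz = mmha::smem_size_in_bytes<T>(params, THREADS_PER_VALUE, 256);
    dim3   grid(params.num_heads, params.batch_size);
    mmha::masked_multihead_attention_kernel<T, Dh, Dh_MAX, 1, THREADS_PER_VALUE, 256, false, 4>
        <<<grid, 256, smem_sz, stream>>>(params);

The trailing semicolon comes from the call site; the macro body ends without one, so an invocation reads like an ordinary statement.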
src/turbomind/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.cuh
@@ -1272,7 +1272,7 @@ template<typename T, // The type of the inputs. Supported types: float and half
          int  THREADS_PER_VALUE,  // The number of threads per value.
          int  THREADS_PER_BLOCK,  // The number of threads in a threadblock.
          bool HAS_BEAMS,
          int  QUANT_POLICY>       // quantization method
 __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T> params)
 {
@@ -1464,7 +1464,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
         if (not QUANT_POLICY) {
             *reinterpret_cast<Qk_vec_m*>(&params.k_cache[offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
-        } else if (QUANT_POLICY == 4) {
+        }
+        else if (QUANT_POLICY == 4) {
             using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
             Packed_Int8_t k_int8 = quant(k, k_scale);
@@ -1486,7 +1487,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
         if (not QUANT_POLICY) {
             *reinterpret_cast<Qk_vec_m*>(&params.k_cache_per_sample[bi][offset]) = vec_conversion<Qk_vec_m, Qk_vec_k>(k);
-        } else if (QUANT_POLICY == 4) {
+        }
+        else if (QUANT_POLICY == 4) {
             using Packed_Int8_t = typename packed_type<int8_t, num_elems<Qk_vec_k>::value>::type;
             Packed_Int8_t k_int8 = quant(k, k_scale);
@@ -1575,11 +1577,12 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
     if (not QUANT_POLICY) {
         k_cache_batch = params.k_cache_per_sample ?
                             (params.k_cache_per_sample[bi] + params.kv_cache_per_sample_offset
                              + hi * params.memory_max_len * Dh + ki) :
                             &params.k_cache[bhi * params.memory_max_len * Dh + ki];
         // Base pointer for the beam's batch, before offsetting with indirection buffer
         // T* k_cache_batch = &params.k_cache[bbhi * params.memory_max_len * Dh + ki];
-    } else if (QUANT_POLICY == 4) {
+    }
+    else if (QUANT_POLICY == 4) {
         // convert k_cache_per_sample to int8
         if (params.k_cache_per_sample) {
             int8_t* ptr = reinterpret_cast<int8_t*>(params.k_cache_per_sample[bi]);
@@ -1628,7 +1631,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
             if (not QUANT_POLICY) {
                 k[ii] = vec_conversion<K_vec_k, K_vec_m>((*reinterpret_cast<const K_vec_m*>(&k_cache_batch[beam_offset + jj * QK_ELTS_IN_16B])));
-            } else if (QUANT_POLICY == 4) {
+            }
+            else if (QUANT_POLICY == 4) {
                 using Packed_Int8_t  = typename packed_type<int8_t, num_elems<K_vec_m>::value>::type;
                 using Packed_Float_t = typename packed_type<float, num_elems<K_vec_m>::value>::type;
@@ -1766,7 +1770,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
         // Base pointer for the beam's batch, before offsetting with indirection buffer
         // T* v_cache_batch = &params.v_cache[bbhi * params.memory_max_len * Dh + vi];
         v_cache_batch = v_cache;
-    } else if (QUANT_POLICY == 4) {
+    }
+    else if (QUANT_POLICY == 4) {
         if (params.v_cache_per_sample) {
             int8_t* ptr  = reinterpret_cast<int8_t*>(params.v_cache_per_sample[bi]);
             v_cache_int8 = ptr + params.kv_cache_per_sample_offset + hi * params.memory_max_len * Dh + vi;
@@ -1831,7 +1836,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
             if (not QUANT_POLICY) {
                 v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti * Dh]));
-            } else if (QUANT_POLICY == 4) {
+            }
+            else if (QUANT_POLICY == 4) {
                 Packed_Int8_t  v_vec_m_int8  = *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti * Dh]);
                 Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
@@ -1877,7 +1883,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
             if (not QUANT_POLICY) {
                 v = vec_conversion<V_vec_k, V_vec_m>(*reinterpret_cast<const V_vec_m*>(&v_cache_batch[beam_offset + ti_circ * Dh]));
-            } else if (QUANT_POLICY == 4) {
+            }
+            else if (QUANT_POLICY == 4) {
                 Packed_Int8_t  v_vec_m_int8  = *reinterpret_cast<const Packed_Int8_t*>(&v_cache_batch_int8[beam_offset + ti_circ * Dh]);
                 Packed_Float_t v_vec_m_float = dequant(v_vec_m_int8, v_scale);
@@ -1931,7 +1938,8 @@ __global__ void masked_multihead_attention_kernel(Multihead_attention_params<T>
         if (not QUANT_POLICY) {
             *reinterpret_cast<V_vec_m*>(&v_cache[tlength_circ * Dh]) = vec_conversion<V_vec_m, V_vec_k>(v);
-        } else if (QUANT_POLICY == 4) {
+        }
+        else if (QUANT_POLICY == 4) {
             using Packed_Int8_t = typename packed_type<int8_t, num_elems<V_vec_k>::value>::type;
             Packed_Int8_t v_int8 = quant(v, v_scale);
             *reinterpret_cast<Packed_Int8_t*>(&v_cache_int8[tlength_circ * Dh]) = v_int8;
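Context for the hunks above: every QUANT_POLICY == 4 branch stores K/V vectors in the cache as int8 through quant(vec, scale) and later converts them back with dequant(vec_int8, scale). A rough, self-contained scalar sketch of that symmetric int8 round trip (illustrative only; the kernel's quant/dequant operate on packed vector types such as Qk_vec_k and V_vec_k, and the scales come from the params struct):

    #include <algorithm>
    #include <cmath>
    #include <cstdint>
    #include <cstdio>

    // Scalar stand-ins for the idea behind quant()/dequant(): map a float to int8
    // with a per-tensor scale and back. The kernel versions do the same per element
    // on packed vectors.
    inline int8_t quant_scalar(float x, float scale)
    {
        float q = std::round(x / scale);
        q       = std::min(127.0f, std::max(-128.0f, q));
        return static_cast<int8_t>(q);
    }

    inline float dequant_scalar(int8_t q, float scale)
    {
        return static_cast<float>(q) * scale;
    }

    int main()
    {
        const float k_scale = 0.05f;   // assumed per-tensor scale
        const float k_elem  = 1.234f;  // one element of a K vector
        int8_t      k_int8  = quant_scalar(k_elem, k_scale);
        float       k_back  = dequant_scalar(k_int8, k_scale);
        printf("%f -> %d -> %f\n", k_elem, k_int8, k_back);
        return 0;
    }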
src/turbomind/python/bind.cpp
 #include "src/turbomind/python/dlpack.h"
-#include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
 #include "src/turbomind/triton_backend/llama/LlamaTritonModel.h"
+#include "src/turbomind/triton_backend/transformer_triton_backend.hpp"
 #include <memory>
 #include <pybind11/pybind11.h>
 #include <pybind11/stl.h>
@@ -302,37 +302,37 @@ PYBIND11_MODULE(_turbomind, m)
     py::class_<AbstractTransformerModelInstance>(m, "AbstractTransformerModelInstance")
         .def("forward",
              [](AbstractTransformerModelInstance* model,
                 std::shared_ptr<TensorMap>        input_tensors,
                 ft::AbstractInstanceComm*         inst_comm) { return model->forward(input_tensors, inst_comm); },
              py::call_guard<py::gil_scoped_release>(),
              "input_tensors"_a,
              "inst_comm"_a = nullptr);

     // transformer model
     py::class_<AbstractTransformerModel, std::shared_ptr<AbstractTransformerModel>>(m, "AbstractTransformerModel")
         // .def_static("create_llama_model", &AbstractTransformerModel::createLlamaModel, "model_dir"_a)
         .def_static("create_llama_model",
                     [](std::string model_dir,
                        size_t      tensor_para_size,
                        size_t      pipeline_para_size,
                        int         enable_custom_all_reduce,
                        std::string data_type) -> std::shared_ptr<AbstractTransformerModel> {
                         if (data_type == "half" || data_type == "fp16") {
                             return std::make_shared<LlamaTritonModel<half>>(
                                 tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
                         }
                         else {
                             return std::make_shared<LlamaTritonModel<float>>(
                                 tensor_para_size, pipeline_para_size, enable_custom_all_reduce, model_dir);
                         }
                     },
                     "model_dir"_a,
                     "tensor_para_size"_a         = 1,
                     "pipeline_para_size"_a       = 1,
                     "enable_custom_all_reduce"_a = 0,
                     "data_type"_a                = "half")
         .def("create_nccl_params", &AbstractTransformerModel::createNcclParams, "node_id"_a,
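The reformatted block above is a standard pybind11 pattern: a factory lambda exposed with def_static, named and defaulted arguments via the "name"_a literals, and the GIL released around the blocking forward() call with py::call_guard<py::gil_scoped_release>(). A minimal self-contained module using the same ingredients (toy class and names, not the lmdeploy bindings):

    #include <memory>
    #include <pybind11/pybind11.h>

    namespace py = pybind11;
    using namespace pybind11::literals;  // enables the "arg_name"_a literals used in bind.cpp

    // Toy stand-in for a model class; illustrative only.
    struct ToyModel {
        explicit ToyModel(int tp): tp_(tp) {}
        int forward(int x) const { return x * tp_; }  // pretend this is a long-running call
        int tp_;
    };

    PYBIND11_MODULE(toy, m)
    {
        py::class_<ToyModel, std::shared_ptr<ToyModel>>(m, "ToyModel")
            // Factory exposed as a static method, mirroring create_llama_model above.
            .def_static(
                "create",
                [](int tensor_para_size) { return std::make_shared<ToyModel>(tensor_para_size); },
                "tensor_para_size"_a = 1)
            // Release the GIL while forward() runs, as the bindings above do.
            .def("forward",
                 [](ToyModel* self, int x) { return self->forward(x); },
                 py::call_guard<py::gil_scoped_release>(),
                 "x"_a);
    }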
src/turbomind/python/dlpack.h
@@ -69,9 +69,11 @@ typedef struct {
  * \brief The device type in DLDevice.
  */
 #ifdef __cplusplus
-typedef enum : int32_t {
+typedef enum : int32_t
+{
 #else
-typedef enum {
+typedef enum
+{
 #endif
   /*! \brief CPU device */
   kDLCPU = 1,
@@ -134,7 +136,8 @@ typedef struct {
 /*!
  * \brief The type code options DLDataType.
  */
-typedef enum {
+typedef enum
+{
   /*! \brief signed integer */
   kDLInt = 0U,
   /*! \brief unsigned integer */
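Side note on the #ifdef __cplusplus split that these hunks preserve: a fixed underlying type (": int32_t") is a C++ feature that pins the enum to a 4-byte representation, presumably so the DLPack structs keep the same layout across compilers, while the #else branch keeps the header valid C. A tiny sketch of the same pattern (illustrative names, not part of dlpack.h):

    #include <cstdint>

    // Same shape as DLDeviceType: fixed 32-bit width under C++, plain enum under C.
    #ifdef __cplusplus
    typedef enum : int32_t
    {
    #else
    typedef enum
    {
    #endif
        kExampleCPU  = 1,
        kExampleCUDA = 2,
    } ExampleDeviceType;

    #ifdef __cplusplus
    static_assert(sizeof(ExampleDeviceType) == 4, "ABI expects a 32-bit enum");
    #endif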