Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
b9e12416
Commit
b9e12416
authored
May 31, 2024
by
zhuwenwen
Browse files
merge v0.4.3
parents
e5d707db
e9d3aa04
Changes
345
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
845 additions
and
815 deletions
+845
-815
csrc/cpu/layernorm.cpp
csrc/cpu/layernorm.cpp
+16
-16
csrc/cpu/pos_encoding.cpp
csrc/cpu/pos_encoding.cpp
+33
-34
csrc/cpu/pybind.cpp
csrc/cpu/pybind.cpp
+23
-52
csrc/cuda_compat.h
csrc/cuda_compat.h
+10
-3
csrc/cuda_utils.h
csrc/cuda_utils.h
+2
-5
csrc/cuda_utils_kernels.cu
csrc/cuda_utils_kernels.cu
+17
-23
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+27
-28
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+51
-54
csrc/custom_all_reduce_test.cu
csrc/custom_all_reduce_test.cu
+19
-19
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+20
-22
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+121
-121
csrc/moe/moe_ops.cpp
csrc/moe/moe_ops.cpp
+2
-1
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+3
-5
csrc/moe_align_block_size_kernels.cu
csrc/moe_align_block_size_kernels.cu
+110
-101
csrc/ops.h
csrc/ops.h
+123
-188
csrc/pos_encoding_kernels.cu
csrc/pos_encoding_kernels.cu
+103
-126
csrc/punica/bgmv/bgmv_config.h
csrc/punica/bgmv/bgmv_config.h
+6
-0
csrc/punica/bgmv/bgmv_impl.cuh
csrc/punica/bgmv/bgmv_impl.cuh
+154
-0
csrc/punica/bgmv/vec_dtypes.cuh
csrc/punica/bgmv/vec_dtypes.cuh
+3
-2
csrc/punica/punica_ops.cu
csrc/punica/punica_ops.cu
+2
-15
No files found.
Too many changes to show.
To preserve performance only
345 of 345+
files are displayed.
Plain diff
Email patch
csrc/cpu/layernorm.cpp
View file @
b9e12416
...
@@ -2,10 +2,10 @@
...
@@ -2,10 +2,10 @@
namespace
{
namespace
{
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
rms_norm_impl
(
scalar_t
*
__restrict__
out
,
void
rms_norm_impl
(
scalar_t
*
__restrict__
out
,
const
scalar_t
*
__restrict__
input
,
const
scalar_t
*
__restrict__
input
,
const
scalar_t
*
__restrict__
weight
,
const
float
epsilon
,
const
scalar_t
*
__restrict__
weight
,
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
int
num_tokens
,
const
int
hidden_size
)
{
using
scalar_vec_t
=
vec_op
::
vec_t
<
scalar_t
>
;
using
scalar_vec_t
=
vec_op
::
vec_t
<
scalar_t
>
;
constexpr
int
VEC_ELEM_NUM
=
scalar_vec_t
::
get_elem_num
();
constexpr
int
VEC_ELEM_NUM
=
scalar_vec_t
::
get_elem_num
();
TORCH_CHECK
(
hidden_size
%
VEC_ELEM_NUM
==
0
);
TORCH_CHECK
(
hidden_size
%
VEC_ELEM_NUM
==
0
);
...
@@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out,
...
@@ -41,11 +41,11 @@ void rms_norm_impl(scalar_t *__restrict__ out,
}
}
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
fused_add_rms_norm_impl
(
scalar_t
*
__restrict__
input
,
void
fused_add_rms_norm_impl
(
scalar_t
*
__restrict__
input
,
scalar_t
*
__restrict__
residual
,
scalar_t
*
__restrict__
residual
,
const
scalar_t
*
__restrict__
weight
,
const
scalar_t
*
__restrict__
weight
,
const
float
epsilon
,
const
int
num_tokens
,
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
int
hidden_size
)
{
using
scalar_vec_t
=
vec_op
::
vec_t
<
scalar_t
>
;
using
scalar_vec_t
=
vec_op
::
vec_t
<
scalar_t
>
;
constexpr
int
VEC_ELEM_NUM
=
scalar_vec_t
::
get_elem_num
();
constexpr
int
VEC_ELEM_NUM
=
scalar_vec_t
::
get_elem_num
();
TORCH_CHECK
(
hidden_size
%
VEC_ELEM_NUM
==
0
);
TORCH_CHECK
(
hidden_size
%
VEC_ELEM_NUM
==
0
);
...
@@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
...
@@ -85,24 +85,24 @@ void fused_add_rms_norm_impl(scalar_t *__restrict__ input,
}
}
}
}
}
}
}
// namespace
}
// namespace
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
weight
,
float
epsilon
)
{
float
epsilon
)
{
int
hidden_size
=
input
.
size
(
-
1
);
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_impl"
,
[
&
]
{
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_impl"
,
[
&
]
{
CPU_KERNEL_GUARD_IN
(
rms_norm_impl
)
CPU_KERNEL_GUARD_IN
(
rms_norm_impl
)
rms_norm_impl
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
rms_norm_impl
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
hidden_size
);
CPU_KERNEL_GUARD_OUT
(
rms_norm_impl
)
CPU_KERNEL_GUARD_OUT
(
rms_norm_impl
)
});
});
}
}
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
float
epsilon
)
{
torch
::
Tensor
&
weight
,
float
epsilon
)
{
int
hidden_size
=
input
.
size
(
-
1
);
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
...
csrc/cpu/pos_encoding.cpp
View file @
b9e12416
...
@@ -4,22 +4,21 @@
...
@@ -4,22 +4,21 @@
namespace
{
namespace
{
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
rotary_embedding_impl
(
void
rotary_embedding_impl
(
const
int64_t
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or
*
__restrict__
positions
,
// [batch_size, seq_len] or
[num_tokens]
//
[num_tokens]
scalar_t
scalar_t
*
__restrict__
query
,
/// [batch_size, seq_len, num_heads,
*
__restrict__
query
,
/// [batch_size, seq_l
en, num_heads,
head_size] or
/// head_size] or [num_tok
en
s
, num_heads,
/// [num_tokens, num_heads,
head_size]
///
head_size]
scalar_t
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
*
__restrict__
key
,
// [batch_size, seq_l
en, num_kv_heads,
head_size] or
// head_size] or [num_tok
en
s
, num_kv_heads,
// [num_tokens, num_kv_heads,
head_size]
//
head_size]
const
scalar_t
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim
// 2]
// 2]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
,
const
int
num_tokens
)
{
const
int
num_tokens
)
{
using
scalar_vec_t
=
vec_op
::
vec_t
<
scalar_t
>
;
using
scalar_vec_t
=
vec_op
::
vec_t
<
scalar_t
>
;
constexpr
int
VEC_ELEM_NUM
=
scalar_vec_t
::
get_elem_num
();
constexpr
int
VEC_ELEM_NUM
=
scalar_vec_t
::
get_elem_num
();
constexpr
int
ELEM_SIZE
=
sizeof
(
scalar_t
);
const
int
embed_dim
=
rot_dim
/
2
;
const
int
embed_dim
=
rot_dim
/
2
;
TORCH_CHECK
(
embed_dim
%
VEC_ELEM_NUM
==
0
);
TORCH_CHECK
(
embed_dim
%
VEC_ELEM_NUM
==
0
);
...
@@ -27,7 +26,7 @@ void rotary_embedding_impl(
...
@@ -27,7 +26,7 @@ void rotary_embedding_impl(
#pragma omp parallel for
#pragma omp parallel for
for
(
int
token_idx
=
0
;
token_idx
<
num_tokens
;
++
token_idx
)
{
for
(
int
token_idx
=
0
;
token_idx
<
num_tokens
;
++
token_idx
)
{
int64_t
pos
=
positions
[
token_idx
];
int64_t
pos
=
positions
[
token_idx
];
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
pos
*
rot_dim
;
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
pos
*
rot_dim
;
for
(
int
i
=
0
;
i
<
num_heads
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_heads
;
++
i
)
{
const
int
head_idx
=
i
;
const
int
head_idx
=
i
;
...
@@ -95,16 +94,16 @@ void rotary_embedding_impl(
...
@@ -95,16 +94,16 @@ void rotary_embedding_impl(
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
void
rotary_embedding_gptj_impl
(
void
rotary_embedding_gptj_impl
(
const
int64_t
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or
*
__restrict__
positions
,
// [batch_size, seq_len] or
[num_tokens]
//
[num_tokens]
scalar_t
scalar_t
*
__restrict__
query
,
/// [batch_size, seq_len, num_heads,
*
__restrict__
query
,
/// [batch_size, seq_l
en, num_heads,
head_size] or
/// head_size] or [num_tok
en
s
, num_heads,
/// [num_tokens, num_heads,
head_size]
///
head_size]
scalar_t
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
*
__restrict__
key
,
// [batch_size, seq_l
en, num_kv_heads,
head_size] or
// head_size] or [num_tok
en
s
, num_kv_heads,
// [num_tokens, num_kv_heads,
head_size]
//
head_size]
const
scalar_t
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim
// 2]
// 2]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
,
const
int
num_tokens
)
{
const
int
num_tokens
)
{
...
@@ -114,13 +113,13 @@ void rotary_embedding_gptj_impl(
...
@@ -114,13 +113,13 @@ void rotary_embedding_gptj_impl(
for
(
int
token_idx
=
0
;
token_idx
<
num_tokens
;
++
token_idx
)
{
for
(
int
token_idx
=
0
;
token_idx
<
num_tokens
;
++
token_idx
)
{
for
(
int
i
=
0
;
i
<
num_heads
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_heads
;
++
i
)
{
int64_t
pos
=
positions
[
token_idx
];
int64_t
pos
=
positions
[
token_idx
];
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
pos
*
rot_dim
;
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
pos
*
rot_dim
;
const
scalar_t
*
cos_cache_ptr
=
cache_ptr
;
const
scalar_t
*
cos_cache_ptr
=
cache_ptr
;
const
scalar_t
*
sin_cache_ptr
=
cache_ptr
+
embed_dim
;
const
scalar_t
*
sin_cache_ptr
=
cache_ptr
+
embed_dim
;
const
int
head_idx
=
i
;
const
int
head_idx
=
i
;
const
int64_t
token_head
=
const
int64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_size
;
token_idx
*
query_stride
+
head_idx
*
head_size
;
scalar_t
*
head_query
=
token_head
+
query
;
scalar_t
*
head_query
=
token_head
+
query
;
for
(
int
j
=
0
;
j
<
embed_dim
;
j
+=
1
)
{
for
(
int
j
=
0
;
j
<
embed_dim
;
j
+=
1
)
{
const
int
rot_offset
=
j
;
const
int
rot_offset
=
j
;
const
int
x_index
=
2
*
rot_offset
;
const
int
x_index
=
2
*
rot_offset
;
...
@@ -142,12 +141,12 @@ void rotary_embedding_gptj_impl(
...
@@ -142,12 +141,12 @@ void rotary_embedding_gptj_impl(
for
(
int
token_idx
=
0
;
token_idx
<
num_tokens
;
++
token_idx
)
{
for
(
int
token_idx
=
0
;
token_idx
<
num_tokens
;
++
token_idx
)
{
for
(
int
i
=
0
;
i
<
num_kv_heads
;
++
i
)
{
for
(
int
i
=
0
;
i
<
num_kv_heads
;
++
i
)
{
int64_t
pos
=
positions
[
token_idx
];
int64_t
pos
=
positions
[
token_idx
];
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
pos
*
rot_dim
;
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
pos
*
rot_dim
;
const
scalar_t
*
cos_cache_ptr
=
cache_ptr
;
const
scalar_t
*
cos_cache_ptr
=
cache_ptr
;
const
scalar_t
*
sin_cache_ptr
=
cache_ptr
+
embed_dim
;
const
scalar_t
*
sin_cache_ptr
=
cache_ptr
+
embed_dim
;
const
int
head_idx
=
i
;
const
int
head_idx
=
i
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
scalar_t
*
head_key
=
key
+
token_head
;
scalar_t
*
head_key
=
key
+
token_head
;
for
(
int
j
=
0
;
j
<
embed_dim
;
j
+=
1
)
{
for
(
int
j
=
0
;
j
<
embed_dim
;
j
+=
1
)
{
const
int
rot_offset
=
j
;
const
int
rot_offset
=
j
;
const
int
x_index
=
2
*
rot_offset
;
const
int
x_index
=
2
*
rot_offset
;
...
@@ -165,11 +164,11 @@ void rotary_embedding_gptj_impl(
...
@@ -165,11 +164,11 @@ void rotary_embedding_gptj_impl(
}
}
}
}
}
}
};
// namespace
};
// namespace
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int
head_size
,
torch
::
Tensor
&
key
,
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
)
{
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
)
{
int
num_tokens
=
query
.
numel
()
/
query
.
size
(
-
1
);
int
num_tokens
=
query
.
numel
()
/
query
.
size
(
-
1
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
...
...
csrc/cpu/pybind.cpp
View file @
b9e12416
...
@@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
...
@@ -8,66 +8,37 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11
::
module
ops
=
m
.
def_submodule
(
"ops"
,
"vLLM custom operators"
);
pybind11
::
module
ops
=
m
.
def_submodule
(
"ops"
,
"vLLM custom operators"
);
// Attention ops
// Attention ops
ops
.
def
(
ops
.
def
(
"paged_attention_v1"
,
&
paged_attention_v1
,
"paged_attention_v1"
,
"Compute the attention between an input query and the cached "
&
paged_attention_v1
,
"keys/values using PagedAttention."
);
"Compute the attention between an input query and the cached keys/values using PagedAttention."
);
ops
.
def
(
"paged_attention_v2"
,
&
paged_attention_v2
,
"PagedAttention V2."
);
ops
.
def
(
"paged_attention_v2"
,
&
paged_attention_v2
,
"PagedAttention V2."
);
// Activation ops
// Activation ops
ops
.
def
(
ops
.
def
(
"silu_and_mul"
,
&
silu_and_mul
,
"Activation function used in SwiGLU."
);
"silu_and_mul"
,
ops
.
def
(
"gelu_and_mul"
,
&
gelu_and_mul
,
&
silu_and_mul
,
"Activation function used in GeGLU with `none` approximation."
);
"Activation function used in SwiGLU."
);
ops
.
def
(
"gelu_tanh_and_mul"
,
&
gelu_tanh_and_mul
,
ops
.
def
(
"Activation function used in GeGLU with `tanh` approximation."
);
"gelu_and_mul"
,
ops
.
def
(
"gelu_new"
,
&
gelu_new
,
"GELU implementation used in GPT-2."
);
&
gelu_and_mul
,
ops
.
def
(
"gelu_fast"
,
&
gelu_fast
,
"Approximate GELU implementation."
);
"Activation function used in GeGLU with `none` approximation."
);
ops
.
def
(
"gelu_tanh_and_mul"
,
&
gelu_tanh_and_mul
,
"Activation function used in GeGLU with `tanh` approximation."
);
ops
.
def
(
"gelu_new"
,
&
gelu_new
,
"GELU implementation used in GPT-2."
);
ops
.
def
(
"gelu_fast"
,
&
gelu_fast
,
"Approximate GELU implementation."
);
// Layernorm
// Layernorm
ops
.
def
(
ops
.
def
(
"rms_norm"
,
&
rms_norm
,
"rms_norm"
,
"Apply Root Mean Square (RMS) Normalization to the input tensor."
);
&
rms_norm
,
"Apply Root Mean Square (RMS) Normalization to the input tensor."
);
ops
.
def
(
ops
.
def
(
"fused_add_rms_norm"
,
&
fused_add_rms_norm
,
"fused_add_rms_norm"
,
"In-place fused Add and RMS Normalization"
);
&
fused_add_rms_norm
,
"In-place fused Add and RMS Normalization"
);
// Rotary embedding
// Rotary embedding
ops
.
def
(
ops
.
def
(
"rotary_embedding"
,
&
rotary_embedding
,
"rotary_embedding"
,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key"
);
&
rotary_embedding
,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key"
);
// Cache ops
// Cache ops
pybind11
::
module
cache_ops
=
m
.
def_submodule
(
"cache_ops"
,
"vLLM cache ops"
);
pybind11
::
module
cache_ops
=
m
.
def_submodule
(
"cache_ops"
,
"vLLM cache ops"
);
cache_ops
.
def
(
cache_ops
.
def
(
"swap_blocks"
,
&
swap_blocks
,
"swap_blocks"
,
"Swap in (out) the cache blocks from src to dst"
);
&
swap_blocks
,
cache_ops
.
def
(
"copy_blocks"
,
&
copy_blocks
,
"Swap in (out) the cache blocks from src to dst"
);
"Copy the cache blocks from src to dst"
);
cache_ops
.
def
(
cache_ops
.
def
(
"reshape_and_cache"
,
&
reshape_and_cache
,
"copy_blocks"
,
"Reshape the key and value tensors and cache them"
);
&
copy_blocks
,
"Copy the cache blocks from src to dst"
);
cache_ops
.
def
(
"reshape_and_cache"
,
&
reshape_and_cache
,
"Reshape the key and value tensors and cache them"
);
}
}
csrc/cuda_compat.h
View file @
b9e12416
#pragma once
#pragma once
#ifdef USE_ROCM
#ifdef USE_ROCM
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#endif
#endif
#ifndef USE_ROCM
#ifndef USE_ROCM
...
@@ -17,7 +17,8 @@
...
@@ -17,7 +17,8 @@
#endif
#endif
#ifndef USE_ROCM
#ifndef USE_ROCM
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor_sync(uint32_t(-1), var, lane_mask)
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) \
__shfl_xor_sync(uint32_t(-1), var, lane_mask)
#else
#else
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#define VLLM_SHFL_XOR_SYNC(var, lane_mask) __shfl_xor(var, lane_mask)
#endif
#endif
...
@@ -28,6 +29,13 @@
...
@@ -28,6 +29,13 @@
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#define VLLM_SHFL_SYNC(var, src_lane) __shfl(var, src_lane)
#endif
#endif
#ifndef USE_ROCM
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) \
__shfl_down_sync(uint32_t(-1), var, lane_delta)
#else
#define VLLM_SHFL_DOWN_SYNC(var, lane_delta) __shfl_down(var, lane_delta)
#endif
#ifndef USE_ROCM
#ifndef USE_ROCM
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
cudaFuncSetAttribute(FUNC, cudaFuncAttributeMaxDynamicSharedMemorySize, VAL)
...
@@ -35,4 +43,3 @@
...
@@ -35,4 +43,3 @@
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
#define VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize(FUNC, VAL) \
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
hipFuncSetAttribute(FUNC, hipFuncAttributeMaxDynamicSharedMemorySize, VAL)
#endif
#endif
csrc/cuda_utils.h
View file @
b9e12416
...
@@ -2,9 +2,6 @@
...
@@ -2,9 +2,6 @@
#include <torch/extension.h>
#include <torch/extension.h>
int
get_device_attribute
(
int
get_device_attribute
(
int
attribute
,
int
device_id
);
int
attribute
,
int
device_id
);
int
get_max_shared_memory_per_block_device_attribute
(
int
get_max_shared_memory_per_block_device_attribute
(
int
device_id
);
int
device_id
);
csrc/cuda_utils_kernels.cu
View file @
b9e12416
...
@@ -2,34 +2,28 @@
...
@@ -2,34 +2,28 @@
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#include <hip/hip_runtime_api.h>
#endif
#endif
int
get_device_attribute
(
int
get_device_attribute
(
int
attribute
,
int
device_id
)
{
int
attribute
,
int
device
,
value
;
int
device_id
)
if
(
device_id
<
0
)
{
{
cudaGetDevice
(
&
device
);
int
device
,
value
;
}
else
{
if
(
device_id
<
0
)
{
device
=
device_id
;
cudaGetDevice
(
&
device
);
}
}
cudaDeviceGetAttribute
(
&
value
,
static_cast
<
cudaDeviceAttr
>
(
attribute
),
else
{
device
);
device
=
device_id
;
return
value
;
}
cudaDeviceGetAttribute
(
&
value
,
static_cast
<
cudaDeviceAttr
>
(
attribute
),
device
);
return
value
;
}
}
int
get_max_shared_memory_per_block_device_attribute
(
int
device_id
)
{
int
get_max_shared_memory_per_block_device_attribute
(
int
attribute
;
int
device_id
)
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
{
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
int
attribute
;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
#ifdef USE_ROCM
#ifdef USE_ROCM
attribute
=
hipDeviceAttributeMaxSharedMemoryPerBlock
;
attribute
=
hipDeviceAttributeMaxSharedMemoryPerBlock
;
#else
#else
attribute
=
cudaDevAttrMaxSharedMemoryPerBlockOptin
;
attribute
=
cudaDevAttrMaxSharedMemoryPerBlockOptin
;
#endif
#endif
return
get_device_attribute
(
attribute
,
device_id
);
return
get_device_attribute
(
attribute
,
device_id
);
}
}
csrc/custom_all_reduce.cu
View file @
b9e12416
...
@@ -7,11 +7,11 @@
...
@@ -7,11 +7,11 @@
// fake pointer type
// fake pointer type
using
fptr_t
=
uint64_t
;
using
fptr_t
=
uint64_t
;
static_assert
(
sizeof
(
void
*
)
==
sizeof
(
fptr_t
));
static_assert
(
sizeof
(
void
*
)
==
sizeof
(
fptr_t
));
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
,
int
rank
,
const
std
::
vector
<
int64_t
>&
offsets
,
int
rank
,
bool
full_nvlink
)
{
bool
full_nvlink
)
{
int
world_size
=
offsets
.
size
();
int
world_size
=
offsets
.
size
();
if
(
world_size
>
8
)
if
(
world_size
>
8
)
...
@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
...
@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
std
::
memcpy
(
&
ipc_handles
[
i
],
handles
[
i
].
data
(),
sizeof
(
cudaIpcMemHandle_t
));
std
::
memcpy
(
&
ipc_handles
[
i
],
handles
[
i
].
data
(),
sizeof
(
cudaIpcMemHandle_t
));
}
}
return
(
fptr_t
)
new
vllm
::
CustomAllreduce
(
return
(
fptr_t
)
new
vllm
::
CustomAllreduce
(
reinterpret_cast
<
vllm
::
Signal
*>
(
meta
.
data_ptr
()),
rank_data
.
data_ptr
(),
reinterpret_cast
<
vllm
::
Signal
*>
(
meta
.
data_ptr
()),
rank_data
.
data_ptr
(),
rank_data
.
numel
(),
ipc_handles
,
offsets
,
rank
,
full_nvlink
);
rank_data
.
numel
(),
ipc_handles
,
offsets
,
rank
,
full_nvlink
);
}
}
...
@@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
...
@@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
*/
bool
_is_weak_contiguous
(
torch
::
Tensor
&
t
)
{
bool
_is_weak_contiguous
(
torch
::
Tensor
&
t
)
{
return
t
.
is_contiguous
()
||
return
t
.
is_contiguous
()
||
(
t
.
storage
().
nbytes
()
-
t
.
storage_offset
()
*
t
.
element_size
()
==
(
t
.
storage
().
nbytes
()
-
t
.
storage_offset
()
*
t
.
element_size
()
==
t
.
numel
()
*
t
.
element_size
());
t
.
numel
()
*
t
.
element_size
());
}
}
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int
max_size
,
int
world_size
,
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int
max_size
,
int
world_size
,
bool
full_nvlink
)
{
bool
full_nvlink
)
{
auto
inp_size
=
inp
.
numel
()
*
inp
.
element_size
();
auto
inp_size
=
inp
.
numel
()
*
inp
.
element_size
();
// custom allreduce requires input byte size to be multiples of 16
// custom allreduce requires input byte size to be multiples of 16
...
@@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
...
@@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
return
false
;
return
false
;
}
}
void
_all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
void
_all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
cudaStream_t
stream
)
{
cudaStream_t
stream
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
TORCH_CHECK
(
_is_weak_contiguous
(
out
));
TORCH_CHECK
(
_is_weak_contiguous
(
out
));
switch
(
out
.
scalar_type
())
{
switch
(
out
.
scalar_type
())
{
case
at
::
ScalarType
::
Float
:
{
case
at
::
ScalarType
::
Float
:
{
fa
->
allreduce
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
inp
.
data_ptr
()),
fa
->
allreduce
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
float
*>
(
out
.
data_ptr
()),
reinterpret_cast
<
float
*>
(
out
.
data_ptr
()),
out
.
numel
());
out
.
numel
());
break
;
break
;
}
}
case
at
::
ScalarType
::
Half
:
{
case
at
::
ScalarType
::
Half
:
{
fa
->
allreduce
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
inp
.
data_ptr
()),
fa
->
allreduce
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
());
out
.
numel
());
break
;
break
;
}
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case
at
::
ScalarType
::
BFloat16
:
{
case
at
::
ScalarType
::
BFloat16
:
{
fa
->
allreduce
<
nv_bfloat16
>
(
fa
->
allreduce
<
nv_bfloat16
>
(
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
inp
.
data_ptr
()),
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
nv_bfloat16
*>
(
out
.
data_ptr
()),
out
.
numel
());
reinterpret_cast
<
nv_bfloat16
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
break
;
}
}
#endif
#endif
...
@@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
...
@@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
}
}
}
}
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
)
{
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
TORCH_CHECK_EQ
(
inp
.
scalar_type
(),
out
.
scalar_type
());
TORCH_CHECK_EQ
(
inp
.
scalar_type
(),
out
.
scalar_type
());
...
@@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
...
@@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
_all_reduce
(
_fa
,
inp
,
out
,
stream
);
_all_reduce
(
_fa
,
inp
,
out
,
stream
);
}
}
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
)
{
torch
::
Tensor
&
out
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
...
@@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
...
@@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
}
}
void
dispose
(
fptr_t
_fa
)
{
void
dispose
(
fptr_t
_fa
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
delete
fa
;
delete
fa
;
}
}
int
meta_size
()
{
return
sizeof
(
vllm
::
Signal
);
}
int
meta_size
()
{
return
sizeof
(
vllm
::
Signal
);
}
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
)
{
const
std
::
vector
<
int64_t
>&
offsets
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
fa
->
register_buffer
(
handles
,
offsets
,
t
.
data_ptr
());
fa
->
register_buffer
(
handles
,
offsets
,
t
.
data_ptr
());
}
}
std
::
pair
<
std
::
vector
<
uint8_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
std
::
pair
<
std
::
vector
<
uint8_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
)
{
fptr_t
_fa
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
return
fa
->
get_graph_buffer_ipc_meta
();
return
fa
->
get_graph_buffer_ipc_meta
();
}
}
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>
&
handles
,
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
offsets
)
{
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
fa
->
register_graph_buffers
(
handles
,
offsets
);
fa
->
register_graph_buffers
(
handles
,
offsets
);
}
}
csrc/custom_all_reduce.cuh
View file @
b9e12416
...
@@ -31,9 +31,9 @@ struct Signal {
...
@@ -31,9 +31,9 @@ struct Signal {
alignas
(
128
)
uint32_t
end
[
kMaxBlocks
][
8
];
alignas
(
128
)
uint32_t
end
[
kMaxBlocks
][
8
];
};
};
struct
__align__
(
16
)
RankData
{
const
void
*
__restrict__
ptrs
[
8
];
};
struct
__align__
(
16
)
RankData
{
const
void
*
__restrict__
ptrs
[
8
];
};
struct
__align__
(
16
)
RankSignals
{
volatile
Signal
*
signals
[
8
];
};
struct
__align__
(
16
)
RankSignals
{
volatile
Signal
*
signals
[
8
];
};
// like std::array, but aligned
// like std::array, but aligned
template
<
typename
T
,
int
sz
>
template
<
typename
T
,
int
sz
>
...
@@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) {
...
@@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) {
// scalar add functions
// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
// bfloat is disabled so we call the intrinsics directly
DINLINE
half
&
assign_add
(
half
&
a
,
half
b
)
{
DINLINE
half
&
assign_add
(
half
&
a
,
half
b
)
{
a
=
__hadd
(
a
,
b
);
a
=
__hadd
(
a
,
b
);
return
a
;
return
a
;
}
}
DINLINE
float
&
assign_add
(
float
&
a
,
float
b
)
{
return
a
+=
b
;
}
DINLINE
float
&
assign_add
(
float
&
a
,
float
b
)
{
return
a
+=
b
;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE
float
upcast_s
(
nv_bfloat16
val
)
{
return
__bfloat162float
(
val
);
}
DINLINE
float
upcast_s
(
nv_bfloat16
val
)
{
return
__bfloat162float
(
val
);
}
...
@@ -80,14 +80,14 @@ template <>
...
@@ -80,14 +80,14 @@ template <>
DINLINE
nv_bfloat16
downcast_s
(
float
val
)
{
DINLINE
nv_bfloat16
downcast_s
(
float
val
)
{
return
__float2bfloat16
(
val
);
return
__float2bfloat16
(
val
);
}
}
DINLINE
nv_bfloat16
&
assign_add
(
nv_bfloat16
&
a
,
nv_bfloat16
b
)
{
DINLINE
nv_bfloat16
&
assign_add
(
nv_bfloat16
&
a
,
nv_bfloat16
b
)
{
a
=
__hadd
(
a
,
b
);
a
=
__hadd
(
a
,
b
);
return
a
;
return
a
;
}
}
#endif
#endif
template
<
typename
T
,
int
N
>
template
<
typename
T
,
int
N
>
DINLINE
array_t
<
T
,
N
>
&
packed_assign_add
(
array_t
<
T
,
N
>
&
a
,
array_t
<
T
,
N
>
b
)
{
DINLINE
array_t
<
T
,
N
>&
packed_assign_add
(
array_t
<
T
,
N
>&
a
,
array_t
<
T
,
N
>
b
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
assign_add
(
a
.
data
[
i
],
b
.
data
[
i
]);
assign_add
(
a
.
data
[
i
],
b
.
data
[
i
]);
...
@@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) {
...
@@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
// other volatile writes.
template
<
int
ngpus
>
template
<
int
ngpus
>
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
int
rank
)
{
int
rank
)
{
if
(
threadIdx
.
x
<
ngpus
)
{
if
(
threadIdx
.
x
<
ngpus
)
{
// reset flag for next time
// reset flag for next time
...
@@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
...
@@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write
// Latency = 1 p2p write
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
]
=
1
;
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
]
=
1
;
// wait until we got true from all ranks
// wait until we got true from all ranks
while
(
!
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
])
while
(
!
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
]);
;
}
}
__syncthreads
();
__syncthreads
();
}
}
...
@@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
...
@@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
// we don't need to make any visibility guarantees for prior memory accesses.
template
<
int
ngpus
,
bool
final_sync
=
false
>
template
<
int
ngpus
,
bool
final_sync
=
false
>
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
int
rank
)
{
int
rank
)
{
__syncthreads
();
__syncthreads
();
// eliminate the case that prior writes are not visible after signals become
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not managed to make this happen through a lot of
// visible. Note that I did not managed to make this happen through a lot of
// testing. Might be the case that hardware provides stronger guarantee than
// testing. Might be the case that hardware provides stronger guarantee than
// the memory model.
// the memory model.
if
constexpr
(
!
final_sync
)
__threadfence_system
();
if
constexpr
(
!
final_sync
)
__threadfence_system
();
if
(
threadIdx
.
x
<
ngpus
)
{
if
(
threadIdx
.
x
<
ngpus
)
{
// reset flag for next time
// reset flag for next time
...
@@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
...
@@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write
// Latency = 1 p2p write
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
]
=
1
;
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
]
=
1
;
// wait until we got true from all ranks
// wait until we got true from all ranks
while
(
!
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
])
while
(
!
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
]);
;
}
}
if
constexpr
(
!
final_sync
)
__syncthreads
();
if
constexpr
(
!
final_sync
)
__syncthreads
();
}
}
template
<
typename
P
,
int
ngpus
,
typename
A
>
template
<
typename
P
,
int
ngpus
,
typename
A
>
DINLINE
P
packed_reduce
(
const
P
*
ptrs
[],
int
idx
)
{
DINLINE
P
packed_reduce
(
const
P
*
ptrs
[],
int
idx
)
{
A
tmp
=
upcast
(
ptrs
[
0
][
idx
]);
A
tmp
=
upcast
(
ptrs
[
0
][
idx
]);
#pragma unroll
#pragma unroll
for
(
int
i
=
1
;
i
<
ngpus
;
i
++
)
{
for
(
int
i
=
1
;
i
<
ngpus
;
i
++
)
{
...
@@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
...
@@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
template
<
typename
T
,
int
ngpus
>
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_1stage
(
RankData
*
_dp
,
RankSignals
sg
,
cross_device_reduce_1stage
(
RankData
*
_dp
,
RankSignals
sg
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
int
rank
,
int
size
)
{
using
P
=
typename
packed_t
<
T
>::
P
;
using
P
=
typename
packed_t
<
T
>::
P
;
using
A
=
typename
packed_t
<
T
>::
A
;
using
A
=
typename
packed_t
<
T
>::
A
;
...
@@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1)
...
@@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1)
// do the actual reduction
// do the actual reduction
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
((
P
*
)
result
)[
idx
]
=
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
}
}
end_sync
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
end_sync
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
}
}
template
<
typename
P
>
template
<
typename
P
>
DINLINE
P
*
get_tmp_buf
(
volatile
Signal
*
sg
)
{
DINLINE
P
*
get_tmp_buf
(
volatile
Signal
*
sg
)
{
return
(
P
*
)(((
Signal
*
)
sg
)
+
1
);
return
(
P
*
)(((
Signal
*
)
sg
)
+
1
);
}
}
template
<
typename
T
,
int
ngpus
>
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_2stage
(
RankData
*
_dp
,
RankSignals
sg
,
cross_device_reduce_2stage
(
RankData
*
_dp
,
RankSignals
sg
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
int
rank
,
int
size
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
gridDim
.
x
*
blockDim
.
x
;
int
stride
=
gridDim
.
x
*
blockDim
.
x
;
...
@@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1)
...
@@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1)
int
start
=
rank
*
part
;
int
start
=
rank
*
part
;
int
end
=
rank
==
ngpus
-
1
?
size
:
start
+
part
;
int
end
=
rank
==
ngpus
-
1
?
size
:
start
+
part
;
int
largest_part
=
part
+
size
%
ngpus
;
int
largest_part
=
part
+
size
%
ngpus
;
const
P
*
ptrs
[
ngpus
];
const
P
*
ptrs
[
ngpus
];
P
*
tmps
[
ngpus
];
P
*
tmps
[
ngpus
];
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
ngpus
;
i
++
)
{
for
(
int
i
=
0
;
i
<
ngpus
;
i
++
)
{
int
target
=
(
rank
+
i
)
%
ngpus
;
int
target
=
(
rank
+
i
)
%
ngpus
;
ptrs
[
i
]
=
(
const
P
*
)
_dp
->
ptrs
[
target
];
ptrs
[
i
]
=
(
const
P
*
)
_dp
->
ptrs
[
target
];
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
}
}
auto
tmp_out
=
tmps
[
0
];
auto
tmp_out
=
tmps
[
0
];
...
@@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1)
...
@@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1)
int
gather_from_rank
=
((
rank
+
i
)
%
ngpus
);
int
gather_from_rank
=
((
rank
+
i
)
%
ngpus
);
if
(
gather_from_rank
==
ngpus
-
1
||
idx
<
part
)
{
if
(
gather_from_rank
==
ngpus
-
1
||
idx
<
part
)
{
int
dst_idx
=
gather_from_rank
*
part
+
idx
;
int
dst_idx
=
gather_from_rank
*
part
+
idx
;
((
P
*
)
result
)[
dst_idx
]
=
tmps
[
i
][
idx
];
((
P
*
)
result
)[
dst_idx
]
=
tmps
[
i
][
idx
];
}
}
}
}
}
}
...
@@ -261,14 +258,14 @@ class CustomAllreduce {
...
@@ -261,14 +258,14 @@ class CustomAllreduce {
// below are device pointers
// below are device pointers
RankSignals
sg_
;
RankSignals
sg_
;
std
::
unordered_map
<
void
*
,
RankData
*>
buffers_
;
std
::
unordered_map
<
void
*
,
RankData
*>
buffers_
;
Signal
*
self_sg_
;
Signal
*
self_sg_
;
// stores the registered device pointers from all ranks
// stores the registered device pointers from all ranks
RankData
*
d_rank_data_base_
,
*
d_rank_data_end_
;
RankData
*
d_rank_data_base_
,
*
d_rank_data_end_
;
std
::
vector
<
void
*>
graph_unreg_buffers_
;
std
::
vector
<
void
*>
graph_unreg_buffers_
;
// a map from IPC handles to opened IPC pointers
// a map from IPC handles to opened IPC pointers
std
::
map
<
IPC_KEY
,
char
*>
ipc_handles_
;
std
::
map
<
IPC_KEY
,
char
*>
ipc_handles_
;
/**
/**
* meta is a pointer to device metadata and temporary buffer for allreduce.
* meta is a pointer to device metadata and temporary buffer for allreduce.
...
@@ -279,22 +276,22 @@ class CustomAllreduce {
...
@@ -279,22 +276,22 @@ class CustomAllreduce {
* note: this class does not own any device memory. Any required buffers
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
* are passed in from the constructor
*/
*/
CustomAllreduce
(
Signal
*
meta
,
void
*
rank_data
,
size_t
rank_data_sz
,
CustomAllreduce
(
Signal
*
meta
,
void
*
rank_data
,
size_t
rank_data_sz
,
const
cudaIpcMemHandle_t
*
handles
,
const
cudaIpcMemHandle_t
*
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
,
int
rank
,
const
std
::
vector
<
int64_t
>&
offsets
,
int
rank
,
bool
full_nvlink
=
true
)
bool
full_nvlink
=
true
)
:
rank_
(
rank
),
:
rank_
(
rank
),
world_size_
(
offsets
.
size
()),
world_size_
(
offsets
.
size
()),
full_nvlink_
(
full_nvlink
),
full_nvlink_
(
full_nvlink
),
self_sg_
(
meta
),
self_sg_
(
meta
),
d_rank_data_base_
(
reinterpret_cast
<
RankData
*>
(
rank_data
)),
d_rank_data_base_
(
reinterpret_cast
<
RankData
*>
(
rank_data
)),
d_rank_data_end_
(
d_rank_data_base_
+
rank_data_sz
/
sizeof
(
RankData
))
{
d_rank_data_end_
(
d_rank_data_base_
+
rank_data_sz
/
sizeof
(
RankData
))
{
for
(
int
i
=
0
;
i
<
world_size_
;
i
++
)
{
for
(
int
i
=
0
;
i
<
world_size_
;
i
++
)
{
Signal
*
rank_sg
;
Signal
*
rank_sg
;
if
(
i
!=
rank_
)
{
if
(
i
!=
rank_
)
{
char
*
handle
=
open_ipc_handle
(
&
handles
[
i
]);
char
*
handle
=
open_ipc_handle
(
&
handles
[
i
]);
handle
+=
offsets
[
i
];
handle
+=
offsets
[
i
];
rank_sg
=
(
Signal
*
)
handle
;
rank_sg
=
(
Signal
*
)
handle
;
}
else
{
}
else
{
rank_sg
=
self_sg_
;
rank_sg
=
self_sg_
;
}
}
...
@@ -302,13 +299,13 @@ class CustomAllreduce {
...
@@ -302,13 +299,13 @@ class CustomAllreduce {
}
}
}
}
char
*
open_ipc_handle
(
const
void
*
ipc_handle
)
{
char
*
open_ipc_handle
(
const
void
*
ipc_handle
)
{
auto
[
it
,
new_handle
]
=
auto
[
it
,
new_handle
]
=
ipc_handles_
.
insert
({
*
((
IPC_KEY
*
)
ipc_handle
),
nullptr
});
ipc_handles_
.
insert
({
*
((
IPC_KEY
*
)
ipc_handle
),
nullptr
});
if
(
new_handle
)
{
if
(
new_handle
)
{
char
*
ipc_ptr
;
char
*
ipc_ptr
;
CUDACHECK
(
cudaIpcOpenMemHandle
((
void
**
)
&
ipc_ptr
,
CUDACHECK
(
cudaIpcOpenMemHandle
((
void
**
)
&
ipc_ptr
,
*
((
const
cudaIpcMemHandle_t
*
)
ipc_handle
),
*
((
const
cudaIpcMemHandle_t
*
)
ipc_handle
),
cudaIpcMemLazyEnablePeerAccess
));
cudaIpcMemLazyEnablePeerAccess
));
it
->
second
=
ipc_ptr
;
it
->
second
=
ipc_ptr
;
}
}
...
@@ -323,7 +320,7 @@ class CustomAllreduce {
...
@@ -323,7 +320,7 @@ class CustomAllreduce {
std
::
vector
<
int64_t
>
offsets
(
num_buffers
);
std
::
vector
<
int64_t
>
offsets
(
num_buffers
);
for
(
int
i
=
0
;
i
<
num_buffers
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_buffers
;
i
++
)
{
auto
ptr
=
graph_unreg_buffers_
[
i
];
auto
ptr
=
graph_unreg_buffers_
[
i
];
void
*
base_ptr
;
void
*
base_ptr
;
// note: must share the base address of each allocation, or we get wrong
// note: must share the base address of each allocation, or we get wrong
// address
// address
if
(
cuPointerGetAttribute
(
&
base_ptr
,
if
(
cuPointerGetAttribute
(
&
base_ptr
,
...
@@ -331,8 +328,8 @@ class CustomAllreduce {
...
@@ -331,8 +328,8 @@ class CustomAllreduce {
(
CUdeviceptr
)
ptr
)
!=
CUDA_SUCCESS
)
(
CUdeviceptr
)
ptr
)
!=
CUDA_SUCCESS
)
throw
std
::
runtime_error
(
"failed to get pointer attr"
);
throw
std
::
runtime_error
(
"failed to get pointer attr"
);
CUDACHECK
(
cudaIpcGetMemHandle
(
CUDACHECK
(
cudaIpcGetMemHandle
(
(
cudaIpcMemHandle_t
*
)
&
handles
[
i
*
handle_sz
],
base_ptr
));
(
cudaIpcMemHandle_t
*
)
&
handles
[
i
*
handle_sz
],
base_ptr
));
offsets
[
i
]
=
((
char
*
)
ptr
)
-
((
char
*
)
base_ptr
);
offsets
[
i
]
=
((
char
*
)
ptr
)
-
((
char
*
)
base_ptr
);
}
}
return
std
::
make_pair
(
handles
,
offsets
);
return
std
::
make_pair
(
handles
,
offsets
);
}
}
...
@@ -344,13 +341,13 @@ class CustomAllreduce {
...
@@ -344,13 +341,13 @@ class CustomAllreduce {
std
::
to_string
(
d_rank_data_base_
+
num
-
d_rank_data_end_
));
std
::
to_string
(
d_rank_data_base_
+
num
-
d_rank_data_end_
));
}
}
void
register_buffer
(
const
std
::
vector
<
std
::
string
>
&
handles
,
void
register_buffer
(
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
,
void
*
self
)
{
const
std
::
vector
<
int64_t
>&
offsets
,
void
*
self
)
{
check_rank_data_capacity
();
check_rank_data_capacity
();
RankData
data
;
RankData
data
;
for
(
int
i
=
0
;
i
<
world_size_
;
i
++
)
{
for
(
int
i
=
0
;
i
<
world_size_
;
i
++
)
{
if
(
i
!=
rank_
)
{
if
(
i
!=
rank_
)
{
char
*
handle
=
open_ipc_handle
(
handles
[
i
].
data
());
char
*
handle
=
open_ipc_handle
(
handles
[
i
].
data
());
handle
+=
offsets
[
i
];
handle
+=
offsets
[
i
];
data
.
ptrs
[
i
]
=
handle
;
data
.
ptrs
[
i
]
=
handle
;
}
else
{
}
else
{
...
@@ -371,17 +368,17 @@ class CustomAllreduce {
...
@@ -371,17 +368,17 @@ class CustomAllreduce {
// got a different address. IPC handles have internal reference counting
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
// mechanism so overhead should be small.
void
register_graph_buffers
(
void
register_graph_buffers
(
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
offsets
)
{
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
auto
num_buffers
=
graph_unreg_buffers_
.
size
();
auto
num_buffers
=
graph_unreg_buffers_
.
size
();
check_rank_data_capacity
(
num_buffers
);
check_rank_data_capacity
(
num_buffers
);
std
::
vector
<
RankData
>
rank_data
(
num_buffers
);
std
::
vector
<
RankData
>
rank_data
(
num_buffers
);
for
(
int
i
=
0
;
i
<
num_buffers
;
i
++
)
{
for
(
int
i
=
0
;
i
<
num_buffers
;
i
++
)
{
auto
self_ptr
=
graph_unreg_buffers_
[
i
];
auto
self_ptr
=
graph_unreg_buffers_
[
i
];
auto
&
rd
=
rank_data
[
i
];
auto
&
rd
=
rank_data
[
i
];
for
(
int
j
=
0
;
j
<
world_size_
;
j
++
)
{
for
(
int
j
=
0
;
j
<
world_size_
;
j
++
)
{
if
(
j
!=
rank_
)
{
if
(
j
!=
rank_
)
{
char
*
handle
=
char
*
handle
=
open_ipc_handle
(
&
handles
[
j
][
i
*
sizeof
(
cudaIpcMemHandle_t
)]);
open_ipc_handle
(
&
handles
[
j
][
i
*
sizeof
(
cudaIpcMemHandle_t
)]);
handle
+=
offsets
[
j
][
i
];
handle
+=
offsets
[
j
][
i
];
rd
.
ptrs
[
j
]
=
handle
;
rd
.
ptrs
[
j
]
=
handle
;
...
@@ -405,7 +402,7 @@ class CustomAllreduce {
...
@@ -405,7 +402,7 @@ class CustomAllreduce {
* will cause contention on NVLink bus.
* will cause contention on NVLink bus.
*/
*/
template
<
typename
T
>
template
<
typename
T
>
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
int
threads
=
512
,
int
block_limit
=
36
)
{
int
threads
=
512
,
int
block_limit
=
36
)
{
auto
d
=
packed_t
<
T
>::
P
::
size
;
auto
d
=
packed_t
<
T
>::
P
::
size
;
if
(
size
%
d
!=
0
)
if
(
size
%
d
!=
0
)
...
@@ -418,7 +415,7 @@ class CustomAllreduce {
...
@@ -418,7 +415,7 @@ class CustomAllreduce {
std
::
to_string
(
kMaxBlocks
)
+
". Got "
+
std
::
to_string
(
kMaxBlocks
)
+
". Got "
+
std
::
to_string
(
block_limit
));
std
::
to_string
(
block_limit
));
RankData
*
ptrs
;
RankData
*
ptrs
;
cudaStreamCaptureStatus
status
;
cudaStreamCaptureStatus
status
;
CUDACHECK
(
cudaStreamIsCapturing
(
stream
,
&
status
));
CUDACHECK
(
cudaStreamIsCapturing
(
stream
,
&
status
));
if
(
status
==
cudaStreamCaptureStatusActive
)
{
if
(
status
==
cudaStreamCaptureStatusActive
)
{
...
...
csrc/custom_all_reduce_test.cu
View file @
b9e12416
...
@@ -48,7 +48,7 @@ __global__ void dummy_kernel() {
...
@@ -48,7 +48,7 @@ __global__ void dummy_kernel() {
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
set_data
(
T
*
data
,
int
size
,
int
myRank
)
{
__global__
void
set_data
(
T
*
data
,
int
size
,
int
myRank
)
{
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
data
[
idx
]
=
myRank
*
0.11
f
;
data
[
idx
]
=
myRank
*
0.11
f
;
...
@@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) {
...
@@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) {
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
convert_data
(
const
T
*
data1
,
const
T
*
data2
,
double
*
fdata1
,
__global__
void
convert_data
(
const
T
*
data1
,
const
T
*
data2
,
double
*
fdata1
,
double
*
fdata2
,
int
size
)
{
double
*
fdata2
,
int
size
)
{
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
fdata1
[
idx
]
=
data1
[
idx
];
fdata1
[
idx
]
=
data1
[
idx
];
...
@@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1,
...
@@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1,
}
}
}
}
__global__
void
init_rand
(
curandState_t
*
state
,
int
size
,
int
nRanks
)
{
__global__
void
init_rand
(
curandState_t
*
state
,
int
size
,
int
nRanks
)
{
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
...
@@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) {
...
@@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) {
}
}
template
<
typename
T
>
template
<
typename
T
>
__global__
void
gen_data
(
curandState_t
*
state
,
T
*
data
,
double
*
ground_truth
,
__global__
void
gen_data
(
curandState_t
*
state
,
T
*
data
,
double
*
ground_truth
,
int
myRank
,
int
nRanks
,
int
size
)
{
int
myRank
,
int
nRanks
,
int
size
)
{
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
...
@@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
...
@@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
}
}
template
<
typename
T
>
template
<
typename
T
>
void
run
(
int
myRank
,
int
nRanks
,
ncclComm_t
&
comm
,
int
threads
,
int
block_limit
,
void
run
(
int
myRank
,
int
nRanks
,
ncclComm_t
&
comm
,
int
threads
,
int
block_limit
,
int
data_size
,
bool
performance_test
)
{
int
data_size
,
bool
performance_test
)
{
T
*
result
;
T
*
result
;
cudaStream_t
stream
;
cudaStream_t
stream
;
CUDACHECK
(
cudaStreamCreateWithFlags
(
&
stream
,
cudaStreamNonBlocking
));
CUDACHECK
(
cudaStreamCreateWithFlags
(
&
stream
,
cudaStreamNonBlocking
));
CUDACHECK
(
cudaMalloc
(
&
result
,
data_size
*
sizeof
(
T
)));
CUDACHECK
(
cudaMalloc
(
&
result
,
data_size
*
sizeof
(
T
)));
...
@@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
...
@@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
cudaIpcMemHandle_t
self_data_handle
;
cudaIpcMemHandle_t
self_data_handle
;
cudaIpcMemHandle_t
data_handles
[
8
];
cudaIpcMemHandle_t
data_handles
[
8
];
vllm
::
Signal
*
buffer
;
vllm
::
Signal
*
buffer
;
T
*
self_data_copy
;
T
*
self_data_copy
;
/**
/**
* Allocate IPC buffer
* Allocate IPC buffer
*
*
...
@@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
...
@@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
MPI_BYTE
,
data_handles
,
sizeof
(
cudaIpcMemHandle_t
),
MPI_BYTE
,
data_handles
,
sizeof
(
cudaIpcMemHandle_t
),
MPI_BYTE
,
MPI_COMM_WORLD
));
MPI_BYTE
,
MPI_COMM_WORLD
));
void
*
rank_data
;
void
*
rank_data
;
size_t
rank_data_sz
=
16
*
1024
*
1024
;
size_t
rank_data_sz
=
16
*
1024
*
1024
;
CUDACHECK
(
cudaMalloc
(
&
rank_data
,
rank_data_sz
));
CUDACHECK
(
cudaMalloc
(
&
rank_data
,
rank_data_sz
));
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
0
);
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
0
);
vllm
::
CustomAllreduce
fa
(
buffer
,
rank_data
,
rank_data_sz
,
data_handles
,
vllm
::
CustomAllreduce
fa
(
buffer
,
rank_data
,
rank_data_sz
,
data_handles
,
offsets
,
myRank
);
offsets
,
myRank
);
auto
*
self_data
=
auto
*
self_data
=
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
char
*>
(
buffer
)
+
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
char
*>
(
buffer
)
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
// hack buffer registration
// hack buffer registration
{
{
std
::
vector
<
std
::
string
>
handles
;
std
::
vector
<
std
::
string
>
handles
;
handles
.
reserve
(
nRanks
);
handles
.
reserve
(
nRanks
);
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
char
*
begin
=
(
char
*
)
&
data_handles
[
i
];
char
*
begin
=
(
char
*
)
&
data_handles
[
i
];
char
*
end
=
(
char
*
)
&
data_handles
[
i
+
1
];
char
*
end
=
(
char
*
)
&
data_handles
[
i
+
1
];
handles
.
emplace_back
(
begin
,
end
);
handles
.
emplace_back
(
begin
,
end
);
}
}
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
...
@@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
...
@@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
fa
.
register_buffer
(
handles
,
offsets
,
self_data
);
fa
.
register_buffer
(
handles
,
offsets
,
self_data
);
}
}
double
*
ground_truth
;
double
*
ground_truth
;
CUDACHECK
(
cudaMallocHost
(
&
ground_truth
,
data_size
*
sizeof
(
double
)));
CUDACHECK
(
cudaMallocHost
(
&
ground_truth
,
data_size
*
sizeof
(
double
)));
curandState_t
*
states
;
curandState_t
*
states
;
CUDACHECK
(
cudaMalloc
(
&
states
,
sizeof
(
curandState_t
)
*
nRanks
*
data_size
));
CUDACHECK
(
cudaMalloc
(
&
states
,
sizeof
(
curandState_t
)
*
nRanks
*
data_size
));
init_rand
<<<
108
,
1024
,
0
,
stream
>>>
(
states
,
data_size
,
nRanks
);
init_rand
<<<
108
,
1024
,
0
,
stream
>>>
(
states
,
data_size
,
nRanks
);
gen_data
<
T
><<<
108
,
1024
,
0
,
stream
>>>
(
states
,
self_data
,
ground_truth
,
myRank
,
gen_data
<
T
><<<
108
,
1024
,
0
,
stream
>>>
(
states
,
self_data
,
ground_truth
,
myRank
,
...
@@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
...
@@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
CUDACHECK
(
cudaStreamDestroy
(
stream
));
CUDACHECK
(
cudaStreamDestroy
(
stream
));
}
}
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
int
nRanks
,
myRank
;
int
nRanks
,
myRank
;
MPICHECK
(
MPI_Init
(
&
argc
,
&
argv
));
MPICHECK
(
MPI_Init
(
&
argc
,
&
argv
));
MPICHECK
(
MPI_Comm_rank
(
MPI_COMM_WORLD
,
&
myRank
));
MPICHECK
(
MPI_Comm_rank
(
MPI_COMM_WORLD
,
&
myRank
));
...
@@ -296,7 +296,7 @@ int main(int argc, char **argv) {
...
@@ -296,7 +296,7 @@ int main(int argc, char **argv) {
ncclUniqueId
id
;
ncclUniqueId
id
;
ncclComm_t
comm
;
ncclComm_t
comm
;
if
(
myRank
==
0
)
ncclGetUniqueId
(
&
id
);
if
(
myRank
==
0
)
ncclGetUniqueId
(
&
id
);
MPICHECK
(
MPI_Bcast
(
static_cast
<
void
*>
(
&
id
),
sizeof
(
id
),
MPI_BYTE
,
0
,
MPICHECK
(
MPI_Bcast
(
static_cast
<
void
*>
(
&
id
),
sizeof
(
id
),
MPI_BYTE
,
0
,
MPI_COMM_WORLD
));
MPI_COMM_WORLD
));
NCCLCHECK
(
ncclCommInitRank
(
&
comm
,
nRanks
,
id
,
myRank
));
NCCLCHECK
(
ncclCommInitRank
(
&
comm
,
nRanks
,
id
,
myRank
));
...
...
csrc/dispatch_utils.h
View file @
b9e12416
...
@@ -6,32 +6,30 @@
...
@@ -6,32 +6,30 @@
#include <torch/extension.h>
#include <torch/extension.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)
\
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)
\
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...)
\
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(
\
AT_DISPATCH_SWITCH(
TYPE, NAME,
\
TYPE, NAME,
VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)
\
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
csrc/layernorm_kernels.cu
View file @
b9e12416
...
@@ -11,26 +11,24 @@
...
@@ -11,26 +11,24 @@
#include <hip/hip_bf16.h>
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
#include <hip/hip_fp16.h>
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat162
=
__hip_bfloat162
;
using
__nv_bfloat162
=
__hip_bfloat162
;
#endif
#endif
namespace
vllm
{
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
rms_norm_kernel
(
__global__
void
rms_norm_kernel
(
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
variance
+=
x
*
x
;
variance
+=
x
*
x
;
}
}
variance
=
blockReduceSum
<
float
>
(
variance
);
variance
=
blockReduceSum
<
float
>
(
variance
);
...
@@ -40,12 +38,12 @@ __global__ void rms_norm_kernel(
...
@@ -40,12 +38,12 @@ __global__ void rms_norm_kernel(
__syncthreads
();
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)
(
x
*
s_variance
))
*
weight
[
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
}
}
}
/* Converter structs for the conversion from torch types to HIP/CUDA types,
/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
to be implemented for now because the relevant type conversion
...
@@ -54,51 +52,68 @@ __global__ void rms_norm_kernel(
...
@@ -54,51 +52,68 @@ __global__ void rms_norm_kernel(
Each struct should have the member static constexpr bool `exists`:
Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
If true, the struct should be fully defined as shown in the examples below.
*/
*/
template
<
typename
torch_type
>
template
<
typename
torch_type
>
struct
_typeConvert
{
static
constexpr
bool
exists
=
false
;
};
struct
_typeConvert
{
static
constexpr
bool
exists
=
false
;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
// CUDA < 12.0 runs into issues with packed type conversion
template
<
>
template
<
>
struct
_typeConvert
<
c10
::
Half
>
{
struct
_typeConvert
<
c10
::
Half
>
{
static
constexpr
bool
exists
=
true
;
static
constexpr
bool
exists
=
true
;
using
hip_type
=
__half
;
using
hip_type
=
__half
;
using
packed_hip_type
=
__half2
;
using
packed_hip_type
=
__half2
;
__device__
static
inline
float
convert
(
hip_type
x
)
{
return
__half2float
(
x
);
}
__device__
static
inline
float
convert
(
hip_type
x
)
{
return
__half2float
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__half22float2
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2half_rn
(
x
);
}
return
__half22float2
(
x
);
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22half2_rn
(
x
);
}
}
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2half_rn
(
x
);
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22half2_rn
(
x
);
}
};
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
// TODO: Add in ROCm support once public headers handle bf16 maturely
template
<
>
template
<
>
struct
_typeConvert
<
c10
::
BFloat16
>
{
struct
_typeConvert
<
c10
::
BFloat16
>
{
static
constexpr
bool
exists
=
true
;
static
constexpr
bool
exists
=
true
;
using
hip_type
=
__nv_bfloat16
;
using
hip_type
=
__nv_bfloat16
;
using
packed_hip_type
=
__nv_bfloat162
;
using
packed_hip_type
=
__nv_bfloat162
;
__device__
static
inline
float
convert
(
hip_type
x
)
{
return
__bfloat162float
(
x
);
}
__device__
static
inline
float
convert
(
hip_type
x
)
{
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__bfloat1622float2
(
x
);
}
return
__bfloat162float
(
x
);
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2bfloat16
(
x
);
}
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22bfloat162_rn
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__bfloat1622float2
(
x
);
}
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2bfloat16
(
x
);
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22bfloat162_rn
(
x
);
}
};
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
*/
template
<
typename
scalar_t
,
int
width
>
template
<
typename
scalar_t
,
int
width
>
struct
alignas
(
16
)
_f16Vec
{
struct
alignas
(
16
)
_f16Vec
{
/* Not theoretically necessary that width is a power of 2 but should
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
almost always be the case for optimization purposes */
static_assert
(
width
>
0
&&
(
width
&
(
width
-
1
))
==
0
,
static_assert
(
width
>
0
&&
(
width
&
(
width
-
1
))
==
0
,
"Width is not a positive power of 2!"
);
"Width is not a positive power of 2!"
);
using
Converter
=
_typeConvert
<
scalar_t
>
;
using
Converter
=
_typeConvert
<
scalar_t
>
;
...
@@ -108,51 +123,49 @@ struct alignas(16) _f16Vec {
...
@@ -108,51 +123,49 @@ struct alignas(16) _f16Vec {
__device__
_f16Vec
&
operator
+=
(
const
_f16Vec
<
scalar_t
,
width
>&
other
)
{
__device__
_f16Vec
&
operator
+=
(
const
_f16Vec
<
scalar_t
,
width
>&
other
)
{
if
constexpr
(
width
%
2
==
0
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
temp
+=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
temp
+=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
data
[
i
]
=
temp
.
x
;
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
}
else
{
}
else
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
for
(
int
i
=
0
;
i
<
width
;
++
i
)
data
[
i
]
+=
other
.
data
[
i
];
data
[
i
]
+=
other
.
data
[
i
];
}
}
return
*
this
;
return
*
this
;
}
}
__device__
_f16Vec
&
operator
*=
(
const
_f16Vec
<
scalar_t
,
width
>&
other
)
{
__device__
_f16Vec
&
operator
*=
(
const
_f16Vec
<
scalar_t
,
width
>&
other
)
{
if
constexpr
(
width
%
2
==
0
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
temp
*=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
temp
*=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
data
[
i
]
=
temp
.
x
;
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
}
else
{
}
else
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
for
(
int
i
=
0
;
i
<
width
;
++
i
)
data
[
i
]
*=
other
.
data
[
i
];
data
[
i
]
*=
other
.
data
[
i
];
}
}
return
*
this
;
return
*
this
;
}
}
__device__
_f16Vec
&
operator
*=
(
const
float
scale
)
{
__device__
_f16Vec
&
operator
*=
(
const
float
scale
)
{
if
constexpr
(
width
%
2
==
0
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
float2
temp_f
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
float2
temp_f
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
temp_f
.
x
*=
scale
;
temp_f
.
x
*=
scale
;
temp_f
.
y
*=
scale
;
temp_f
.
y
*=
scale
;
T2
temp
=
Converter
::
convert
(
temp_f
);
T2
temp
=
Converter
::
convert
(
temp_f
);
data
[
i
]
=
temp
.
x
;
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
}
else
{
}
else
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
float
temp
=
Converter
::
convert
(
data
[
i
])
*
scale
;
float
temp
=
Converter
::
convert
(
data
[
i
])
*
scale
;
data
[
i
]
=
Converter
::
convert
(
temp
);
data
[
i
]
=
Converter
::
convert
(
temp
);
...
@@ -164,13 +177,13 @@ struct alignas(16) _f16Vec {
...
@@ -164,13 +177,13 @@ struct alignas(16) _f16Vec {
__device__
float
sum_squares
()
const
{
__device__
float
sum_squares
()
const
{
float
result
=
0.0
f
;
float
result
=
0.0
f
;
if
constexpr
(
width
%
2
==
0
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
float2
z
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
float2
z
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
result
+=
z
.
x
*
z
.
x
+
z
.
y
*
z
.
y
;
result
+=
z
.
x
*
z
.
x
+
z
.
y
*
z
.
y
;
}
}
}
else
{
}
else
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
float
x
=
Converter
::
convert
(
data
[
i
]);
float
x
=
Converter
::
convert
(
data
[
i
]);
result
+=
x
*
x
;
result
+=
x
*
x
;
...
@@ -184,15 +197,13 @@ struct alignas(16) _f16Vec {
...
@@ -184,15 +197,13 @@ struct alignas(16) _f16Vec {
Additional optimizations we can make in this case are
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
packed and vectorized operations, which help with the
memory latency bottleneck. */
memory latency bottleneck. */
template
<
typename
scalar_t
,
int
width
>
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
int
num_tokens
,
const
int
hidden_size
)
{
// Sanity checks on our vector struct and type-punned pointer arithmetic
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert
(
std
::
is_pod_v
<
_f16Vec
<
scalar_t
,
width
>>
);
static_assert
(
std
::
is_pod_v
<
_f16Vec
<
scalar_t
,
width
>>
);
static_assert
(
sizeof
(
_f16Vec
<
scalar_t
,
width
>
)
==
sizeof
(
scalar_t
)
*
width
);
static_assert
(
sizeof
(
_f16Vec
<
scalar_t
,
width
>
)
==
sizeof
(
scalar_t
)
*
width
);
...
@@ -203,9 +214,12 @@ __global__ std::enable_if_t<
...
@@ -203,9 +214,12 @@ __global__ std::enable_if_t<
/* These and the argument pointers are all declared `restrict` as they are
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
in this kernel as that would be undefined behavior */
auto
*
__restrict__
input_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
input
);
auto
*
__restrict__
input_v
=
auto
*
__restrict__
residual_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
residual
);
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
input
);
auto
*
__restrict__
weight_v
=
reinterpret_cast
<
const
_f16Vec
<
scalar_t
,
width
>*>
(
weight
);
auto
*
__restrict__
residual_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
residual
);
auto
*
__restrict__
weight_v
=
reinterpret_cast
<
const
_f16Vec
<
scalar_t
,
width
>*>
(
weight
);
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
...
@@ -215,10 +229,11 @@ __global__ std::enable_if_t<
...
@@ -215,10 +229,11 @@ __global__ std::enable_if_t<
residual_v
[
id
]
=
temp
;
residual_v
[
id
]
=
temp
;
}
}
/* Keep the following if-else block in sync with the
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
calculation of max_block_size in fused_add_rms_norm */
if
(
num_tokens
<
256
)
{
if
(
num_tokens
<
256
)
{
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
}
...
@@ -233,52 +248,50 @@ __global__ std::enable_if_t<
...
@@ -233,52 +248,50 @@ __global__ std::enable_if_t<
}
}
}
}
/* Generic fused_add_rms_norm_kernel
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
The width field is not used here but necessary for other specializations.
*/
*/
template
<
typename
scalar_t
,
int
width
>
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
scalar_t
z
=
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
scalar_t
z
=
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
z
+=
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
z
+=
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x
=
(
float
)
z
;
float
x
=
(
float
)
z
;
variance
+=
x
*
x
;
variance
+=
x
*
x
;
residual
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
z
;
residual
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
z
;
}
}
/* Keep the following if-else block in sync with the
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
calculation of max_block_size in fused_add_rms_norm */
if
(
num_tokens
<
256
)
{
if
(
num_tokens
<
256
)
{
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
}
__syncthreads
();
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
input
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)
(
x
*
s_variance
))
*
weight
[
idx
];
input
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
}
}
}
}
// namespace vllm
}
// namespace vllm
void
rms_norm
(
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
float
epsilon
)
{
float
epsilon
)
{
int
hidden_size
=
input
.
size
(
-
1
);
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
@@ -286,40 +299,27 @@ void rms_norm(
...
@@ -286,40 +299,27 @@ void rms_norm(
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
input
.
scalar_type
(),
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
"rms_norm_kernel"
,
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
[
&
]
{
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
});
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
}
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
"fused_add_rms_norm_kernel", \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
[&] { \
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
vllm::fused_add_rms_norm_kernel \
residual.data_ptr<scalar_t>(), \
<scalar_t, width><<<grid, block, 0, stream>>>( \
weight.data_ptr<scalar_t>(), epsilon, \
input.data_ptr<scalar_t>(), \
num_tokens, hidden_size); \
residual.data_ptr<scalar_t>(), \
});
weight.data_ptr<scalar_t>(), \
epsilon, \
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
// [..., hidden_size]
num_tokens, \
torch
::
Tensor
&
residual
,
// [..., hidden_size]
hidden_size); \
torch
::
Tensor
&
weight
,
// [hidden_size]
});
float
epsilon
)
{
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
float
epsilon
)
{
int
hidden_size
=
input
.
size
(
-
1
);
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
@@ -342,8 +342,8 @@ void fused_add_rms_norm(
...
@@ -342,8 +342,8 @@ void fused_add_rms_norm(
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
\
bool
ptrs_are_aligned
=
&&
wt_ptr
%
16
==
0
;
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
)
{
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
}
else
{
...
...
csrc/moe/moe_ops.cpp
View file @
b9e12416
...
@@ -3,5 +3,6 @@
...
@@ -3,5 +3,6 @@
#include <torch/extension.h>
#include <torch/extension.h>
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"topk_softmax"
,
&
topk_softmax
,
"Apply topk softmax to the gating outputs."
);
m
.
def
(
"topk_softmax"
,
&
topk_softmax
,
"Apply topk softmax to the gating outputs."
);
}
}
csrc/moe/moe_ops.h
View file @
b9e12416
...
@@ -2,8 +2,6 @@
...
@@ -2,8 +2,6 @@
#include <torch/extension.h>
#include <torch/extension.h>
void
topk_softmax
(
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
gating_output
);
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
gating_output
);
csrc/moe_align_block_size_kernels.cu
View file @
b9e12416
...
@@ -7,119 +7,128 @@
...
@@ -7,119 +7,128 @@
#include "cuda_compat.h"
#include "cuda_compat.h"
#include "dispatch_utils.h"
#include "dispatch_utils.h"
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
#define CEILDIV(x,
y) (((x) + (y) - 1) / (y))
namespace
vllm
{
namespace
vllm
{
namespace
{
namespace
{
__device__
__forceinline__
int32_t
index
(
int32_t
total_col
,
int32_t
row
,
int32_t
col
)
{
__device__
__forceinline__
int32_t
index
(
int32_t
total_col
,
int32_t
row
,
// don't worry about overflow because num_experts is relatively small
int32_t
col
)
{
return
row
*
total_col
+
col
;
// don't worry about overflow because num_experts is relatively small
}
return
row
*
total_col
+
col
;
}
}
}
// namespace
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
moe_align_block_size_kernel
(
scalar_t
*
__restrict__
topk_ids
,
__global__
void
moe_align_block_size_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
sorted_token_ids
,
int32_t
*
sorted_token_ids
,
int32_t
*
expert_ids
,
int32_t
*
expert_ids
,
int32_t
*
total_tokens_post_pad
,
int32_t
*
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
num_experts
,
int32_t
block_size
,
int32_t
block_size
,
size_t
numel
)
{
size_t
numel
)
{
const
size_t
tokens_per_thread
=
CEILDIV
(
numel
,
blockDim
.
x
);
const
size_t
tokens_per_thread
=
CEILDIV
(
numel
,
blockDim
.
x
);
const
size_t
start_idx
=
threadIdx
.
x
*
tokens_per_thread
;
const
size_t
start_idx
=
threadIdx
.
x
*
tokens_per_thread
;
extern
__shared__
int32_t
shared_mem
[];
extern
__shared__
int32_t
shared_mem
[];
int32_t
*
tokens_cnts
=
int32_t
*
tokens_cnts
=
shared_mem
;
// 2d tensor with shape (num_experts + 1, num_experts)
shared_mem
;
// 2d tensor with shape (num_experts + 1, num_experts)
int32_t
*
cumsum
=
shared_mem
+
(
num_experts
+
1
)
*
num_experts
;
// 1d tensor with shape (num_experts + 1)
int32_t
*
cumsum
=
shared_mem
+
(
num_experts
+
1
)
*
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
num_experts
;
// 1d tensor with shape (num_experts + 1)
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
i
)]
=
0
;
}
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
i
)]
=
0
;
/**
}
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are assigned
/**
* to expert expert_index.
* In the first step we compute token_cnts[thread_index + 1][expert_index],
*/
* which counts how many tokens in the token shard of thread_index are
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
* assigned to expert expert_index.
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
topk_ids
[
i
])];
*/
}
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
topk_ids
[
i
])];
__syncthreads
();
}
// For each expert we accumulate the token counts from the different threads.
__syncthreads
();
tokens_cnts
[
index
(
num_experts
,
0
,
threadIdx
.
x
)]
=
0
;
for
(
int
i
=
1
;
i
<=
blockDim
.
x
;
++
i
)
{
// For each expert we accumulate the token counts from the different threads.
tokens_cnts
[
index
(
num_experts
,
i
,
threadIdx
.
x
)]
+=
tokens_cnts
[
index
(
num_experts
,
i
-
1
,
threadIdx
.
x
)];
tokens_cnts
[
index
(
num_experts
,
0
,
threadIdx
.
x
)]
=
0
;
}
for
(
int
i
=
1
;
i
<=
blockDim
.
x
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
i
,
threadIdx
.
x
)]
+=
__syncthreads
();
tokens_cnts
[
index
(
num_experts
,
i
-
1
,
threadIdx
.
x
)];
}
// We accumulate the token counts of all experts in thread 0.
if
(
threadIdx
.
x
==
0
)
{
__syncthreads
();
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
// We accumulate the token counts of all experts in thread 0.
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
CEILDIV
(
tokens_cnts
[
index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
block_size
)
*
block_size
;
if
(
threadIdx
.
x
==
0
)
{
}
cumsum
[
0
]
=
0
;
*
total_tokens_post_pad
=
cumsum
[
num_experts
];
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
}
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
CEILDIV
(
tokens_cnts
[
index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
__syncthreads
();
block_size
)
*
block_size
;
/**
* For each expert, each thread processes the tokens of the corresponding blocks
* and stores the corresponding expert_id for each block.
*/
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
i
+=
block_size
)
{
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
}
}
*
total_tokens_post_pad
=
cumsum
[
num_experts
];
/**
}
* Each thread processes a token shard, calculating the index of each token after
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
__syncthreads
();
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
* where * represents a padding value(preset in python).
/**
*/
* For each expert, each thread processes the tokens of the corresponding
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
* blocks and stores the corresponding expert_id for each block.
int32_t
expert_id
=
topk_ids
[
i
];
*/
/** The cumsum[expert_id] stores the starting index of the tokens that the
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
i
+=
block_size
)
{
* stores the indices of the tokens processed by the expert with expert_id within
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
* the current thread's token shard.
}
*/
int32_t
rank_post_pad
=
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)]
+
cumsum
[
expert_id
];
/**
sorted_token_ids
[
rank_post_pad
]
=
i
;
* Each thread processes a token shard, calculating the index of each token
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)];
* after sorting by expert number. Given the example topk_ids =
}
* [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
}
* *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
*/
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
int32_t
expert_id
=
topk_ids
[
i
];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t
rank_post_pad
=
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)]
+
cumsum
[
expert_id
];
sorted_token_ids
[
rank_post_pad
]
=
i
;
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)];
}
}
}
}
// namespace vllm
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
num_experts
,
int
num_experts
,
int
block_size
,
torch
::
Tensor
sorted_token_ids
,
int
block_size
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
num_tokens_post_pad
)
{
torch
::
Tensor
experts_ids
,
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
torch
::
Tensor
num_tokens_post_pad
)
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
VLLM_DISPATCH_INTEGRAL_TYPES
(
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
// tensors
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
const
int32_t
shared_mem
=
const
int32_t
shared_mem
=
((
num_experts
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
sizeof
(
int32_t
);
((
num_experts
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
sizeof
(
int32_t
);
// set dynamic shared mem
// set dynamic shared mem
auto
kernel
=
vllm
::
moe_align_block_size_kernel
<
scalar_t
>
;
auto
kernel
=
vllm
::
moe_align_block_size_kernel
<
scalar_t
>
;
AT_CUDA_CHECK
(
AT_CUDA_CHECK
(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize
(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize
(
(
void
*
)
kernel
,
shared_mem
));
(
void
*
)
kernel
,
shared_mem
));
kernel
<<<
1
,
num_experts
,
shared_mem
,
stream
>>>
(
kernel
<<<
1
,
num_experts
,
shared_mem
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
());
topk_ids
.
numel
());
});
});
}
}
csrc/ops.h
View file @
b9e12416
...
@@ -3,204 +3,139 @@
...
@@ -3,204 +3,139 @@
#include <torch/extension.h>
#include <torch/extension.h>
void
paged_attention_v1
(
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
block_size
,
torch
::
Tensor
&
value_cache
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
int
num_kv_heads
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
,
const
int
tp_rank
,
float
scale
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
torch
::
Tensor
&
block_tables
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
);
torch
::
Tensor
&
seq_lens
,
int
block_size
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
void
paged_attention_v2
(
void
paged_attention_v2
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
block_size
,
torch
::
Tensor
&
query
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
key_cache
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
,
const
int
tp_rank
,
torch
::
Tensor
&
value_cache
,
const
int
blocksparse_local_blocks
,
const
int
blocksparse_vert_stride
,
int
num_kv_heads
,
const
int
blocksparse_block_size
,
const
int
blocksparse_head_sliding_step
);
float
scale
,
torch
::
Tensor
&
block_tables
,
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
seq_lens
,
float
epsilon
);
int
block_size
,
int
max_seq_len
,
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
torch
::
Tensor
&
weight
,
float
epsilon
);
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int
head_size
,
void
rms_norm
(
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
weight
,
torch
::
Tensor
&
key
,
int
head_size
,
float
epsilon
);
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int
rot_dim
,
void
fused_add_rms_norm
(
torch
::
Tensor
&
cos_sin_cache_offsets
);
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
torch
::
Tensor
&
weight
,
float
epsilon
);
void
gelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
rotary_embedding
(
void
gelu_tanh_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
void
gelu_new
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
torch
::
Tensor
&
key
,
int
head_size
,
void
gelu_fast
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_tanh_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_new
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_fast
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
#ifndef USE_ROCM
#ifndef USE_ROCM
torch
::
Tensor
aqlm_gemm
(
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
scales
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codebook_partition_sizes
,
const
torch
::
Tensor
&
scales
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
);
const
torch
::
Tensor
&
codebook_partition_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
);
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codebook_partition_sizes
);
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
const
torch
::
Tensor
&
codebooks
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
const
torch
::
Tensor
&
codebook_partition_sizes
int
split_k_iters
);
);
torch
::
Tensor
awq_dequantize
(
torch
::
Tensor
_kernel
,
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_zeros
,
int
split_k_iters
,
int
thx
,
torch
::
Tensor
_kernel
,
int
thy
);
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
torch
::
Tensor
marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
int
split_k_iters
);
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
awq_dequantize
(
torch
::
Tensor
_kernel
,
torch
::
Tensor
gptq_marlin_24_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
_zeros
,
torch
::
Tensor
&
b_scales
,
int
split_k_iters
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int
thx
,
int64_t
size_m
,
int64_t
size_n
,
int
thy
);
int64_t
size_k
);
torch
::
Tensor
marlin_gemm
(
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
torch
::
Tensor
&
b_scales
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
torch
::
Tensor
&
workspace
,
int64_t
size_k
,
bool
is_k_full
);
int64_t
size_m
,
int64_t
size_n
,
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
);
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
int
cutlass_scaled_mm_dq
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
const
&
b_scales
);
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
);
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
#endif
#endif
void
squeezellm_gemm
(
// void static_scaled_int8_quant(torch::Tensor& out, torch::Tensor& input,
torch
::
Tensor
vec
,
// float scale);
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
void
squeezellm_gemm
(
torch
::
Tensor
vec
,
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
torch
::
Tensor
lookup_table
);
torch
::
Tensor
lookup_table
);
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
a
,
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
torch
::
Tensor
b_gptq_qzeros
,
bool
use_exllama
,
int
bit
);
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int
bit
);
bool
use_exllama
,
int
bit
);
// void static_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
// torch::Tensor& scale);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
// void dynamic_scaled_fp8_quant(torch::Tensor& out, torch::Tensor& input,
torch
::
Tensor
q_perm
,
// torch::Tensor& scale);
int
bit
);
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
num_experts
,
// void static_scaled_fp8_quant(
int
block_size
,
torch
::
Tensor
sorted_token_ids
,
// torch::Tensor& out,
torch
::
Tensor
experts_ids
,
// torch::Tensor& input,
torch
::
Tensor
num_tokens_post_pad
);
// torch::Tensor& scale);
// void dynamic_scaled_fp8_quant(
// torch::Tensor& out,
// torch::Tensor& input,
// torch::Tensor& scale);
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
num_experts
,
int
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
#ifndef USE_ROCM
#ifndef USE_ROCM
using
fptr_t
=
uint64_t
;
using
fptr_t
=
uint64_t
;
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
,
int
rank
,
const
std
::
vector
<
int64_t
>&
offsets
,
int
rank
,
bool
full_nvlink
);
bool
full_nvlink
);
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int
max_size
,
int
world_size
,
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int
max_size
,
int
world_size
,
bool
full_nvlink
);
bool
full_nvlink
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
);
torch
::
Tensor
&
out
);
void
dispose
(
fptr_t
_fa
);
void
dispose
(
fptr_t
_fa
);
int
meta_size
();
int
meta_size
();
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
);
const
std
::
vector
<
int64_t
>&
offsets
);
std
::
pair
<
std
::
vector
<
uint8_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
);
std
::
pair
<
std
::
vector
<
uint8_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>
&
handles
,
fptr_t
_fa
);
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
offsets
);
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
);
#endif
#endif
csrc/pos_encoding_kernels.cu
View file @
b9e12416
...
@@ -7,14 +7,10 @@
...
@@ -7,14 +7,10 @@
namespace
vllm
{
namespace
vllm
{
template
<
typename
scalar_t
,
bool
IS_NEOX
>
template
<
typename
scalar_t
,
bool
IS_NEOX
>
inline
__device__
void
apply_token_rotary_embedding
(
inline
__device__
void
apply_token_rotary_embedding
(
scalar_t
*
__restrict__
arr
,
scalar_t
*
__restrict__
arr
,
const
scalar_t
*
__restrict__
cos_ptr
,
const
scalar_t
*
__restrict__
cos_ptr
,
const
scalar_t
*
__restrict__
sin_ptr
,
int
rot_offset
,
int
embed_dim
)
{
const
scalar_t
*
__restrict__
sin_ptr
,
int
rot_offset
,
int
embed_dim
)
{
int
x_index
,
y_index
;
int
x_index
,
y_index
;
scalar_t
cos
,
sin
;
scalar_t
cos
,
sin
;
if
(
IS_NEOX
)
{
if
(
IS_NEOX
)
{
...
@@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding(
...
@@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding(
arr
[
y_index
]
=
y
*
cos
+
x
*
sin
;
arr
[
y_index
]
=
y
*
cos
+
x
*
sin
;
}
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
template
<
typename
scalar_t
,
bool
IS_NEOX
>
inline
__device__
void
apply_rotary_embedding
(
inline
__device__
void
apply_rotary_embedding
(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
// head_size] or [num_tokens, num_heads,
const
scalar_t
*
cache_ptr
,
// head_size]
const
int
head_size
,
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
const
int
num_heads
,
// head_size] or [num_tokens, num_kv_heads,
const
int
num_kv_heads
,
// head_size]
const
int
rot_dim
,
const
scalar_t
*
cache_ptr
,
const
int
head_size
,
const
int
num_heads
,
const
int
token_idx
,
const
int
num_kv_heads
,
const
int
rot_dim
,
const
int
token_idx
,
const
int64_t
query_stride
,
const
int64_t
query_stride
,
const
int64_t
key_stride
)
{
const
int64_t
key_stride
)
{
const
int
embed_dim
=
rot_dim
/
2
;
const
int
embed_dim
=
rot_dim
/
2
;
const
scalar_t
*
cos_ptr
=
cache_ptr
;
const
scalar_t
*
cos_ptr
=
cache_ptr
;
const
scalar_t
*
sin_ptr
=
cache_ptr
+
embed_dim
;
const
scalar_t
*
sin_ptr
=
cache_ptr
+
embed_dim
;
...
@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
...
@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
const
int
head_idx
=
i
/
embed_dim
;
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_size
;
const
int64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_size
;
const
int
rot_offset
=
i
%
embed_dim
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
+
token_head
,
cos_ptr
,
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
sin_ptr
,
rot_offset
,
embed_dim
);
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
}
const
int
nk
=
num_kv_heads
*
embed_dim
;
const
int
nk
=
num_kv_heads
*
embed_dim
;
...
@@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding(
...
@@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding(
const
int
head_idx
=
i
/
embed_dim
;
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
const
int
rot_offset
=
i
%
embed_dim
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
sin_ptr
,
rot_offset
,
embed_dim
);
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
}
}
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
template
<
typename
scalar_t
,
bool
IS_NEOX
>
__global__
void
rotary_embedding_kernel
(
__global__
void
rotary_embedding_kernel
(
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or [num_tokens]
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
// [num_tokens]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim // 2]
// head_size] or [num_tokens, num_heads,
const
int
rot_dim
,
// head_size]
const
int64_t
query_stride
,
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
const
int64_t
key_stride
,
// head_size] or [num_tokens, num_kv_heads,
const
int
num_heads
,
// head_size]
const
int
num_kv_heads
,
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
const
int
head_size
)
{
// 2]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
int64_t
pos
=
positions
[
token_idx
];
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
pos
*
rot_dim
;
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
pos
*
rot_dim
;
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
}
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
template
<
typename
scalar_t
,
bool
IS_NEOX
>
__global__
void
batched_rotary_embedding_kernel
(
__global__
void
batched_rotary_embedding_kernel
(
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or [num_tokens]
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
// [num_tokens]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim // 2]
// head_size] or [num_tokens, num_heads,
const
int64_t
*
__restrict__
cos_sin_cache_offsets
,
// [batch_size, seq_len] or [num_tokens]
// head_size]
const
int
rot_dim
,
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
const
int64_t
query_stride
,
// head_size] or [num_tokens, num_kv_heads,
const
int64_t
key_stride
,
// head_size]
const
int
num_heads
,
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
const
int
num_kv_heads
,
// 2]
const
int
head_size
)
{
const
int64_t
*
__restrict__
cos_sin_cache_offsets
,
// [batch_size, seq_len]
// or [num_tokens]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
int64_t
pos
=
positions
[
token_idx
];
int64_t
cos_sin_cache_offset
=
cos_sin_cache_offsets
[
token_idx
];
int64_t
cos_sin_cache_offset
=
cos_sin_cache_offsets
[
token_idx
];
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
(
cos_sin_cache_offset
+
pos
)
*
rot_dim
;
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
(
cos_sin_cache_offset
+
pos
)
*
rot_dim
;
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
}
}
}
// namespace vllm
}
// namespace vllm
void
rotary_embedding
(
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
// [batch_size, seq_len] or [num_tokens]
torch
::
Tensor
&
positions
,
// [batch_size, seq_len] or [num_tokens]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
// [num_tokens, num_heads * head_size]
int
head_size
,
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
// [num_tokens, num_kv_heads * head_size]
bool
is_neox
)
{
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
)
{
int64_t
num_tokens
=
query
.
numel
()
/
query
.
size
(
-
1
);
int64_t
num_tokens
=
query
.
numel
()
/
query
.
size
(
-
1
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
...
@@ -135,36 +141,21 @@ void rotary_embedding(
...
@@ -135,36 +141,21 @@ void rotary_embedding(
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"rotary_embedding"
,
[
&
]
{
query
.
scalar_type
(),
if
(
is_neox
)
{
"rotary_embedding"
,
vllm
::
rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
[
&
]
{
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
if
(
is_neox
)
{
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
vllm
::
rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
positions
.
data_ptr
<
int64_t
>
(),
}
else
{
query
.
data_ptr
<
scalar_t
>
(),
vllm
::
rotary_embedding_kernel
<
scalar_t
,
false
>
key
.
data_ptr
<
scalar_t
>
(),
<<<
grid
,
block
,
0
,
stream
>>>
(
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
query_stride
,
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
key_stride
,
head_size
);
num_heads
,
}
num_kv_heads
,
});
head_size
);
}
else
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
}
}
/*
/*
...
@@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together
...
@@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together
and process in batched manner.
and process in batched manner.
*/
*/
void
batched_rotary_embedding
(
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
// [batch_size, seq_len] or [num_tokens]
torch
::
Tensor
&
positions
,
// [batch_size, seq_len] or [num_tokens]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
// [num_tokens, num_heads * head_size]
int
head_size
,
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
// [num_tokens, num_kv_heads * head_size]
bool
is_neox
,
int
head_size
,
int
rot_dim
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
torch
::
Tensor
&
cos_sin_cache_offsets
// [num_tokens]
bool
is_neox
,
int
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
// [num_tokens]
)
{
)
{
int64_t
num_tokens
=
cos_sin_cache_offsets
.
size
(
0
);
int64_t
num_tokens
=
cos_sin_cache_offsets
.
size
(
0
);
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
...
@@ -191,36 +183,21 @@ void batched_rotary_embedding(
...
@@ -191,36 +183,21 @@ void batched_rotary_embedding(
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"rotary_embedding"
,
[
&
]
{
query
.
scalar_type
(),
if
(
is_neox
)
{
"rotary_embedding"
,
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
true
>
[
&
]
{
<<<
grid
,
block
,
0
,
stream
>>>
(
if
(
is_neox
)
{
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
positions
.
data_ptr
<
int64_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
query
.
data_ptr
<
scalar_t
>
(),
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
key
.
data_ptr
<
scalar_t
>
(),
}
else
{
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
>
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
<<<
grid
,
block
,
0
,
stream
>>>
(
rot_dim
,
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
query_stride
,
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
key_stride
,
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
num_heads
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
num_kv_heads
,
}
head_size
);
});
}
else
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
}
}
csrc/punica/bgmv/bgmv_config.h
View file @
b9e12416
...
@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -28,6 +28,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 2816) \
f(in_T, out_T, W_T, narrow, 2816) \
f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3328) \
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \
f(in_T, out_T, W_T, narrow, 4096) \
...
@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -36,6 +37,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6400) \
f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6848) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 7168) \
...
@@ -53,6 +55,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -53,6 +55,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, narrow, 22016) \
f(in_T, out_T, W_T, narrow, 22016) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 27392) \
f(in_T, out_T, W_T, narrow, 27392) \
f(in_T, out_T, W_T, narrow, 27648) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \
f(in_T, out_T, W_T, narrow, 32256) \
...
@@ -96,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -96,6 +99,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2752, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 2816, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3072, narrow) \
f(in_T, out_T, W_T, 3328, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3456, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 3584, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
f(in_T, out_T, W_T, 4096, narrow) \
...
@@ -104,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -104,6 +108,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5504, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 5632, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6144, narrow) \
f(in_T, out_T, W_T, 6400, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6848, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 6912, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
f(in_T, out_T, W_T, 7168, narrow) \
...
@@ -121,6 +126,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -121,6 +126,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 22016, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 24576, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 27392, narrow) \
f(in_T, out_T, W_T, 27648, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 28672, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32000, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \
f(in_T, out_T, W_T, 32256, narrow) \
...
...
csrc/punica/bgmv/bgmv_impl.cuh
View file @
b9e12416
#pragma once
#pragma once
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/CUDAContext.h>
#ifndef USE_ROCM
#include <cooperative_groups.h>
#include <cooperative_groups.h>
#else
#include <hip/hip_cooperative_groups.h>
#endif
#ifndef USE_ROCM
#include <cuda/pipeline>
#include <cuda/pipeline>
#endif
#include <cuda_runtime.h>
#include <cuda_runtime.h>
#include <iostream>
#include <iostream>
#include <stdio.h>
#include <stdio.h>
...
@@ -11,6 +17,24 @@
...
@@ -11,6 +17,24 @@
namespace
cg
=
cooperative_groups
;
namespace
cg
=
cooperative_groups
;
#ifdef USE_ROCM
template
<
size_t
len
>
__host__
__device__
inline
void
*
memcpy_blocking
(
void
*
dst
,
const
void
*
src
)
{
// Does not handle the case of long datatypes
char
*
d
=
reinterpret_cast
<
char
*>
(
dst
);
const
char
*
s
=
reinterpret_cast
<
const
char
*>
(
src
);
size_t
i
=
0
;
#pragma unroll
for
(
i
=
0
;
i
<
len
;
++
i
)
{
d
[
i
]
=
s
[
i
];
}
return
dst
;
}
#endif
#ifndef USE_ROCM
// nthrs = (32, 4)
// nthrs = (32, 4)
template
<
int
feat_in
,
int
feat_out
,
size_t
vec_size
,
size_t
X_copy_size
,
template
<
int
feat_in
,
int
feat_out
,
size_t
vec_size
,
size_t
X_copy_size
,
size_t
W_copy_size
,
int
tx
,
int
ty
,
int
tz
,
typename
in_T
,
size_t
W_copy_size
,
int
tx
,
int
ty
,
int
tz
,
typename
in_T
,
...
@@ -141,6 +165,81 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -141,6 +165,81 @@ bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
}
}
}
}
#else
template
<
int
feat_in
,
int
feat_out
,
size_t
vec_size
,
size_t
X_copy_size
,
size_t
W_copy_size
,
int
tx
,
int
ty
,
int
tz
,
typename
in_T
,
typename
out_T
,
typename
W_T
>
__global__
void
bgmv_shrink_kernel
(
out_T
*
__restrict__
Y
,
const
in_T
*
__restrict__
X
,
const
W_T
*
__restrict__
W
,
const
int64_t
*
__restrict__
indicies
,
int64_t
y_offset
,
int64_t
full_y_size
,
int64_t
num_layers
,
int64_t
layer_idx
,
float
scale
)
{
size_t
batch_idx
=
blockIdx
.
y
;
int64_t
idx
=
indicies
[
batch_idx
]
*
num_layers
+
layer_idx
;
if
(
idx
<
0
)
{
return
;
}
size_t
j
=
blockIdx
.
x
;
constexpr
size_t
tile_size
=
tx
*
ty
*
vec_size
;
constexpr
size_t
num_tiles
=
(
feat_in
+
tile_size
-
1
)
/
tile_size
;
__shared__
float
y_warpwise
[
ty
];
float
y
=
0
;
vec_t
<
in_T
,
vec_size
>
x_vec
;
vec_t
<
W_T
,
vec_size
>
w_vec
;
size_t
tile_idx
;
#pragma unroll
for
(
tile_idx
=
0
;
tile_idx
<
num_tiles
;
++
tile_idx
)
{
if
(
tile_idx
*
tile_size
+
(
threadIdx
.
y
*
tx
+
threadIdx
.
x
+
1
)
*
vec_size
-
1
<
feat_in
)
{
x_vec
.
load
(
X
+
(
batch_idx
*
feat_in
)
+
tile_idx
*
tile_size
+
(
threadIdx
.
y
*
tx
+
threadIdx
.
x
)
*
vec_size
);
w_vec
.
load
(
W
+
(
idx
*
feat_out
+
j
)
*
feat_in
+
tile_idx
*
tile_size
+
(
threadIdx
.
y
*
tx
+
threadIdx
.
x
)
*
vec_size
);
}
float
sum
=
0.
f
;
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
;
++
i
)
{
sum
+=
convert_type
<
W_T
,
float
>
(
w_vec
[
i
])
*
convert_type
<
in_T
,
float
>
(
x_vec
[
i
])
*
scale
;
}
#pragma unroll
for
(
size_t
offset
=
tx
/
2
;
offset
>
0
;
offset
/=
2
)
{
sum
+=
VLLM_SHFL_DOWN_SYNC
(
sum
,
offset
);
}
__syncthreads
();
if
(
tile_idx
*
tile_size
+
(
threadIdx
.
y
*
tx
+
threadIdx
.
x
+
1
)
*
vec_size
-
1
<
feat_in
)
{
y
+=
sum
;
}
}
if
(
threadIdx
.
x
==
0
)
{
y_warpwise
[
threadIdx
.
y
]
=
y
;
}
__syncthreads
();
float
y_write
=
0.
f
;
#pragma unroll
for
(
size_t
i
=
0
;
i
<
ty
;
++
i
)
{
y_write
+=
y_warpwise
[
i
];
}
// write Y;
if
(
threadIdx
.
x
==
0
&&
threadIdx
.
y
==
0
)
{
size_t
y_idx
=
batch_idx
*
full_y_size
+
y_offset
+
j
;
Y
[
y_idx
]
=
vllm_add
<
out_T
>
(
Y
[
y_idx
],
convert_type
<
float
,
out_T
>
(
y_write
));
}
}
#endif
// nthrs = (2, 16, 4)
// nthrs = (2, 16, 4)
template
<
int
feat_in
,
int
feat_out
,
size_t
vec_size
,
int
tx
,
int
ty
,
int
tz
,
template
<
int
feat_in
,
int
feat_out
,
size_t
vec_size
,
int
tx
,
int
ty
,
int
tz
,
typename
in_T
,
typename
out_T
,
typename
W_T
>
typename
in_T
,
typename
out_T
,
typename
W_T
>
...
@@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -172,7 +271,11 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
float
sum
=
0.
f
;
float
sum
=
0.
f
;
#pragma unroll
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
;
++
i
)
{
for
(
size_t
i
=
0
;
i
<
vec_size
;
++
i
)
{
#ifndef USE_ROCM
sum
+=
float
(
w_vec
[
i
])
*
float
(
x_vec
[
i
])
*
scale
;
sum
+=
float
(
w_vec
[
i
])
*
float
(
x_vec
[
i
])
*
scale
;
#else
sum
+=
convert_type
<
W_T
,
float
>
(
w_vec
[
i
])
*
convert_type
<
in_T
,
float
>
(
x_vec
[
i
])
*
scale
;
#endif
}
}
cg
::
thread_block_tile
g
=
cg
::
tiled_partition
<
tx
>
(
block
);
cg
::
thread_block_tile
g
=
cg
::
tiled_partition
<
tx
>
(
block
);
...
@@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -183,8 +286,14 @@ bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
sum
=
g
.
shfl
(
sum
,
0
);
sum
=
g
.
shfl
(
sum
,
0
);
if
(
threadIdx
.
x
==
0
)
{
if
(
threadIdx
.
x
==
0
)
{
#ifndef USE_ROCM
Y
[
batch_idx
*
full_y_size
+
y_offset
+
tile_idx
*
(
tz
*
ty
)
+
Y
[
batch_idx
*
full_y_size
+
y_offset
+
tile_idx
*
(
tz
*
ty
)
+
threadIdx
.
z
*
ty
+
threadIdx
.
y
]
+=
static_cast
<
out_T
>
(
sum
);
threadIdx
.
z
*
ty
+
threadIdx
.
y
]
+=
static_cast
<
out_T
>
(
sum
);
#else
size_t
y_idx
=
batch_idx
*
full_y_size
+
y_offset
+
tile_idx
*
(
tz
*
ty
)
+
threadIdx
.
z
*
ty
+
threadIdx
.
y
;
Y
[
y_idx
]
=
vllm_add
<
out_T
>
(
Y
[
y_idx
],
convert_type
<
float
,
out_T
>
(
sum
));
#endif
}
}
}
}
...
@@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -236,6 +345,7 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
scale
);
scale
);
}
}
}
else
{
}
else
{
#ifndef USE_ROCM
static_assert
(
feat_in
%
(
vec_size
*
32
)
==
0
||
static_assert
(
feat_in
%
(
vec_size
*
32
)
==
0
||
feat_in
%
(
vec_size
*
16
)
==
0
||
feat_in
%
(
vec_size
*
16
)
==
0
||
feat_in
%
(
vec_size
*
8
)
==
0
);
feat_in
%
(
vec_size
*
8
)
==
0
);
...
@@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
...
@@ -279,6 +389,50 @@ void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
full_y_size
,
num_layers
,
layer_idx
,
full_y_size
,
num_layers
,
layer_idx
,
scale
);
scale
);
}
}
#else
constexpr
size_t
rocm_warp_size
=
warpSize
;
#define CHECK_INPUT_TILEABLE_BY(vec_size_) \
feat_in % (rocm_warp_size * vec_size_) == 0
#define LAUNCH_BGMV_SHRINK_KERNELS_ROCM(factor_, vec_size_, tx_, ty_) \
if constexpr (CHECK_INPUT_TILEABLE_BY(factor_)) { \
constexpr size_t vec_size_shrink = vec_size_; \
constexpr int tx = tx_; \
constexpr int ty = ty_; \
dim3 nblks(feat_out, batch_size); \
dim3 nthrs(tx, ty); \
bgmv_shrink_kernel<feat_in, feat_out, vec_size_shrink, \
vec_size_shrink * sizeof(in_T), \
vec_size_shrink * sizeof(W_T), \
tx, ty, tz> \
<<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset, \
full_y_size, num_layers, layer_idx, \
scale); \
}
static_assert
(
CHECK_INPUT_TILEABLE_BY
(
32
)
||
CHECK_INPUT_TILEABLE_BY
(
16
)
||
CHECK_INPUT_TILEABLE_BY
(
8
)
||
CHECK_INPUT_TILEABLE_BY
(
4
)
||
CHECK_INPUT_TILEABLE_BY
(
2
)
||
CHECK_INPUT_TILEABLE_BY
(
1
));
LAUNCH_BGMV_SHRINK_KERNELS_ROCM
(
32
,
vec_size
,
rocm_warp_size
,
32
/
vec_size
)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM
(
16
,
vec_size
,
rocm_warp_size
,
16
/
vec_size
)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM
(
8
,
vec_size
,
rocm_warp_size
,
8
/
vec_size
)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM
(
4
,
vec_size
,
rocm_warp_size
/
(
vec_size
/
4
),
vec_size
/
4
)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM
(
2
,
vec_size
,
rocm_warp_size
/
(
vec_size
/
2
),
vec_size
/
2
)
else
LAUNCH_BGMV_SHRINK_KERNELS_ROCM
(
1
,
vec_size
,
rocm_warp_size
/
(
vec_size
/
1
),
vec_size
/
1
)
#undef CHECK_INPUT_TILEABLE_BY
#undef LAUNCH_BGMV_SHRINK_KERNELS_ROCM
#endif
}
}
}
}
...
...
csrc/punica/bgmv/vec_dtypes.cuh
View file @
b9e12416
#ifndef VEC_DTYPES_CUH_
#ifndef VEC_DTYPES_CUH_
#define VEC_DTYPES_CUH_
#define VEC_DTYPES_CUH_
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifdef FLASHINFER_USE_FP8
#ifdef FLASHINFER_USE_FP8
#include <cuda_fp8.h>
#include <cuda_fp8.h>
#endif
#endif
...
@@ -10,6 +8,9 @@
...
@@ -10,6 +8,9 @@
#include <type_traits>
#include <type_traits>
#include "../type_convert.h"
#include "../../cuda_compat.h"
#define FLASHINFER_INLINE \
#define FLASHINFER_INLINE \
inline __attribute__((always_inline)) __device__ __host__
inline __attribute__((always_inline)) __device__ __host__
...
...
csrc/punica/punica_ops.c
c
→
csrc/punica/punica_ops.c
u
View file @
b9e12416
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
#include <c10/cuda/CUDAGuard.h>
#include <cstdint>
#include <cstdint>
#include "type_convert.h"
#include "../cuda_compat.h"
#include "bgmv/bgmv_config.h"
#include "bgmv/bgmv_config.h"
namespace
{
//====== utils ======
//====== utils ======
...
@@ -568,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
...
@@ -568,15 +567,3 @@ void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
TORCH_CHECK
(
ok
,
"No suitable kernel."
,
" h_in="
,
h_in
,
" h_out="
,
h_out
,
TORCH_CHECK
(
ok
,
"No suitable kernel."
,
" h_in="
,
h_in
,
" h_out="
,
h_out
,
" dtype="
,
x
.
scalar_type
(),
" out_dtype="
,
y
.
scalar_type
());
" dtype="
,
x
.
scalar_type
(),
" out_dtype="
,
y
.
scalar_type
());
}
}
}
// namespace
//====== pybind ======
#define DEFINE_pybind(name) m.def(#name, &name, #name);
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"dispatch_bgmv"
,
&
dispatch_bgmv
,
"dispatch_bgmv"
);
m
.
def
(
"dispatch_bgmv_low_level"
,
&
dispatch_bgmv_low_level
,
"dispatch_bgmv_low_level"
);
}
Prev
1
2
3
4
5
6
7
…
18
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment