xdb4_94051 / vllm · Commit e00b0a19

Commit e00b0a19, authored Mar 23, 2024 by zhuwenwen

    merge v0.3.3

Parents: ead94d93, 3f1166ab
Changes: 239

Showing 20 changed files with 2362 additions and 4 deletions (+2362, -4)
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu    +4    -0
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu    +4    -0
csrc/punica/bgmv/bgmv_config.h             +61   -0
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu    +4    -0
csrc/punica/bgmv/bgmv_impl.cuh             +294  -0
csrc/punica/bgmv/generator.py              +27   -0
csrc/punica/bgmv/vec_dtypes.cuh            +1324 -0
csrc/punica/punica_ops.cc                  +563  -0
csrc/pybind.cpp                            +37   -4
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu  0 → 100644

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu  0 → 100644

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
csrc/punica/bgmv/bgmv_config.h  0 → 100644

#pragma once

template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale);
// clang-format off
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, narrow, 128) \
f(in_T, out_T, W_T, narrow, 256) \
f(in_T, out_T, W_T, narrow, 512) \
f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1280) \
f(in_T, out_T, W_T, narrow, 1728) \
f(in_T, out_T, W_T, narrow, 1792) \
f(in_T, out_T, W_T, narrow, 2048) \
f(in_T, out_T, W_T, narrow, 2560) \
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \
f(in_T, out_T, W_T, narrow, 5120) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 8192) \
f(in_T, out_T, W_T, narrow, 9216) \
f(in_T, out_T, W_T, narrow, 10240) \
f(in_T, out_T, W_T, narrow, 11008) \
f(in_T, out_T, W_T, narrow, 12288) \
f(in_T, out_T, W_T, narrow, 13824) \
f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \
f(in_T, out_T, W_T, narrow, 32512) \
f(in_T, out_T, W_T, narrow, 32768) \
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
// clang-format on
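
For orientation (an illustrative expansion, not text from the commit): combined with the INST_BGMV / INST_BGMV_TWOSIDE macros defined at the end of bgmv_impl.cuh below, every (narrow, wide) pair enumerated by these macros becomes a pair of explicit template instantiations of bgmv_kernel. For the 16/4096 pair with nv_half used for inputs, outputs, and weights, the expansion is roughly:

// Illustrative expansion of INST_BGMV_TWOSIDE(nv_half, nv_half, nv_half, 16, 4096)
template void bgmv_kernel<16, 4096>(
    nv_half *__restrict__ Y, const nv_half *__restrict__ X,
    const nv_half *__restrict__ W, const int64_t *__restrict__ indicies,
    int64_t y_offset, int64_t full_y_size, int64_t batch_size,
    int64_t num_layers, int64_t layer_idx, float scale);
template void bgmv_kernel<4096, 16>(
    nv_half *__restrict__ Y, const nv_half *__restrict__ X,
    const nv_half *__restrict__ W, const int64_t *__restrict__ indicies,
    int64_t y_offset, int64_t full_y_size, int64_t batch_size,
    int64_t num_layers, int64_t layer_idx, float scale);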
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu  0 → 100644

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu  0 → 100644

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu  0 → 100644

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu  0 → 100644

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu  0 → 100644

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu  0 → 100644

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu  0 → 100644

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu  0 → 100644

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu  0 → 100644

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu  0 → 100644

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu  0 → 100644

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu  0 → 100644

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
csrc/punica/bgmv/bgmv_impl.cuh  0 → 100644

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>

#include "vec_dtypes.cuh"

namespace cg = cooperative_groups;

// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
          typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t j = blockIdx.x;
  constexpr size_t num_pipeline_stages = 2;
  constexpr size_t tile_size = tx * ty * vec_size;
  __shared__ W_T W_shared[num_pipeline_stages * tile_size];
  __shared__ in_T X_shared[num_pipeline_stages * tile_size];
  __shared__ float y_warpwise[ty];

  size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  auto pipe = cuda::make_pipeline();

  // pipeline load W/X and compute WX;
  pipe.producer_acquire();
  cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     W + (idx * feat_out + j) * feat_in +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
  cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     X + (batch_idx * feat_in) +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
  pipe.producer_commit();
  size_t copy_idx, compute_idx;
  float y = 0.f;
  vec_t<in_T, vec_size> x_vec;
  vec_t<W_T, vec_size> w_vec;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
       ++tile_idx) {
    copy_idx = tile_idx % num_pipeline_stages;
    // pipeline stage: async copy W fragment
    pipe.producer_acquire();
    if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
      cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         W + (idx * feat_out + j) * feat_in +
                             tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
      cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         X + (batch_idx * feat_in) + tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
    }
    pipe.producer_commit();

    compute_idx = (tile_idx - 1) % num_pipeline_stages;
    // pipeline stage: compute WX
    pipe.consumer_wait();
    block.sync();
    x_vec.load(X_shared + X_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    w_vec.load(W_shared + W_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += float(w_vec[i]) * float(x_vec[i]) * scale;
    }
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    y_warpwise[threadIdx.y] = sum;
    block.sync();
#pragma unroll
    for (size_t i = 0; i < ty; ++i) {
      y += y_warpwise[i];
    }

    block.sync();
    pipe.consumer_release();
  }

  compute_idx = (tile_idx - 1) % num_pipeline_stages;
  // final pipeline stage
  pipe.consumer_wait();
  block.sync();
  x_vec.load(X_shared + X_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  w_vec.load(W_shared + W_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  y_warpwise[threadIdx.y] =
      ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
          ? sum
          : 0.f;
  block.sync();
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y += y_warpwise[i];
  }

  block.sync();
  pipe.consumer_release();

  // write Y;
  if (block.thread_rank() == 0) {
    Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
  }
}

// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
          typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;

  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t tile_idx = blockIdx.x;

  // load X;
  vec_t<in_T, vec_size> x_vec;
  x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);

  // load W;
  vec_t<W_T, vec_size> w_vec;
  w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
             block.thread_rank() * vec_size);

  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }

  cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += g.shfl_down(sum, offset);
  }
  sum = g.shfl(sum, 0);

  if (threadIdx.x == 0) {
    Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
      threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
  }
}

template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale) {
  constexpr size_t vec_size = 8;
  constexpr int tz = 4;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if constexpr (feat_in < feat_out) {
    static_assert(feat_in % vec_size == 0);
    constexpr int tx = feat_in / vec_size;

    static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
                  (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
                  (8 % tx == 0 && feat_out % (8 / tx * tz) == 0));

    if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
      constexpr int ty = 32 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
      constexpr int ty = 16 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else {
      constexpr int ty = 8 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  } else {
    static_assert(feat_in % (vec_size * 32) == 0 ||
                  feat_in % (vec_size * 16) == 0 ||
                  feat_in % (vec_size * 8) == 0);

    if constexpr (feat_in % (vec_size * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
                         vec_size * sizeof(W_T), tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
      constexpr int tx = 16;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  }
}

#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
template void bgmv_kernel<feat_in, feat_out>( \
out_T * __restrict__ Y, const in_T *__restrict__ X, \
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)
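
In short, the host-side bgmv_kernel dispatcher above launches bgmv_expand_kernel when feat_in < feat_out (the LoRA expand direction) and the double-buffered, cuda::pipeline-based bgmv_shrink_kernel otherwise. Both kernels reduce per-lane partial dot products with the same shuffle idiom; a minimal stand-alone sketch of that idiom (illustrative only, not code from this file) looks like:

// Illustrative sketch: each of the tx participating lanes holds a partial sum;
// successive halving shuffles leave the segment's total in its first lane.
__device__ float segment_reduce_sum(float partial, int tx) {
  for (int offset = tx / 2; offset > 0; offset /= 2) {
    partial += __shfl_down_sync(0xffffffff, partial, offset);
  }
  return partial;  // meaningful in lane 0 of the segment
}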
csrc/punica/bgmv/generator.py  0 → 100644

DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
    "fp16": "nv_half",
    "bf16": "nv_bfloat16",
    "fp32": "float",
}

TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()

for input_dtype in DTYPES:
    for output_dtype in DTYPES:
        for weight_dtype in DTYPES:
            if weight_dtype == "fp32":
                # FP32 weights are not supported.
                continue
            kernel_definition = TEMPLATE.format(
                input_dtype=DTYPE_MAP[input_dtype],
                output_dtype=DTYPE_MAP[output_dtype],
                weight_dtype=DTYPE_MAP[weight_dtype])
            filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
            with open(filename, "w") as f:
                f.write(kernel_definition)
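
Run from csrc/punica/bgmv/, this script writes one .cu file per retained dtype combination; fp32 weights are skipped, which is why no bgmv_*_*_fp32.cu variants appear in the file list above. For example, the fp16/fp16/fp16 combination produces exactly the contents of bgmv_fp16_fp16_fp16.cu shown earlier:

// Generated bgmv_fp16_fp16_fp16.cu (as shown above in this commit)
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)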
csrc/punica/bgmv/vec_dtypes.cuh  0 → 100644

(Diff collapsed in this view; 1324 added lines not shown.)
csrc/punica/punica_ops.cc  0 → 100644

(Diff collapsed in this view; 563 added lines not shown.)
csrc/pybind.cpp

@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
  ops.def(
    "gelu_and_mul",
    &gelu_and_mul,
    "Activation function used in GeGLU.");
  ops.def(
    "gelu_new",
    &gelu_new,

@@ -48,13 +52,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    &rotary_embedding,
    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

  // Quantization ops
#ifndef USE_ROCM
  // Quantization ops
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
  ops.def("moe_align_block_size", &moe_align_block_size,
          "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");

  // Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");

@@ -71,9 +82,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    &reshape_and_cache,
    "Reshape the key and value tensors and cache them");
  cache_ops.def(
    "gather_cached_kv",
    &gather_cached_kv,
    "Gather key and value from the cache into contiguous QKV tensors");
  cache_ops.def(
    "convert_fp8_e5m2",
    &convert_fp8_e5m2,
    "Convert the key and value cache to fp8_e5m2 data type");

  // Cuda utils
  pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");

@@ -81,4 +92,26 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    "get_device_attribute",
    &get_device_attribute,
    "Gets the specified device attribute.");
  cuda_utils.def(
    "get_max_shared_memory_per_block_device_attribute",
    &get_max_shared_memory_per_block_device_attribute,
    "Gets the maximum shared memory per block device attribute.");

#ifndef USE_ROCM
  // Custom all-reduce kernels
  pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
  custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
  custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
  custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
  custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
  custom_ar.def("dispose", &dispose, "dispose");
  custom_ar.def("meta_size", &meta_size, "meta_size");
  custom_ar.def("register_buffer", &register_buffer, "register_buffer");
  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
                "get_graph_buffer_ipc_meta");
  custom_ar.def("register_graph_buffers", &register_graph_buffers,
                "register_graph_buffers");
#endif
}