norm / vllm — Commits

Commit e00b0a19, authored Mar 23, 2024 by zhuwenwen
Parents: ead94d93, 3f1166ab

    merge v0.3.3

Changes: 239 files in this commit. Showing 20 changed files here, with 2362 additions and 4 deletions (+2362 / -4).
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu    +4    -0
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu    +4    -0
csrc/punica/bgmv/bgmv_config.h             +61   -0
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu    +4    -0
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu    +4    -0
csrc/punica/bgmv/bgmv_impl.cuh             +294  -0
csrc/punica/bgmv/generator.py              +27   -0
csrc/punica/bgmv/vec_dtypes.cuh            +1324 -0
csrc/punica/punica_ops.cc                  +563  -0
csrc/pybind.cpp                            +37   -4
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)

csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
csrc/punica/bgmv/bgmv_config.h (new file, mode 100644)

#pragma once

template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale);

// clang-format off
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, narrow, 128) \
f(in_T, out_T, W_T, narrow, 256) \
f(in_T, out_T, W_T, narrow, 512) \
f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1280) \
f(in_T, out_T, W_T, narrow, 1728) \
f(in_T, out_T, W_T, narrow, 1792) \
f(in_T, out_T, W_T, narrow, 2048) \
f(in_T, out_T, W_T, narrow, 2560) \
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \
f(in_T, out_T, W_T, narrow, 5120) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 8192) \
f(in_T, out_T, W_T, narrow, 9216) \
f(in_T, out_T, W_T, narrow, 10240) \
f(in_T, out_T, W_T, narrow, 11008) \
f(in_T, out_T, W_T, narrow, 12288) \
f(in_T, out_T, W_T, narrow, 13824) \
f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \
f(in_T, out_T, W_T, narrow, 32512) \
f(in_T, out_T, W_T, narrow, 32768) \
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
// clang-format on
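
For illustration only (not part of the diff): FOR_BGMV_WIDE_NARROW applies its argument macro to every (narrow, wide) pair, with narrow in {8, 16, 32, 64} (LoRA ranks, per the LoRAConfig comment above) and wide drawn from the feature-size list. With INST_BGMV_TWOSIDE (defined later in bgmv_impl.cuh) as the argument, each pair instantiates bgmv_kernel in both directions. A rough sketch of what the preprocessor would emit for the single pair (16, 4096) with nv_half throughout:

// Hypothetical expansion sketch, assuming narrow = 16 and wide = 4096.
// bgmv_kernel<16, 4096> is the expand direction (rank 16 -> 4096 features);
// bgmv_kernel<4096, 16> is the shrink direction (4096 features -> rank 16).
template void bgmv_kernel<16, 4096>(
    nv_half *__restrict__ Y, const nv_half *__restrict__ X,
    const nv_half *__restrict__ W, const int64_t *__restrict__ indicies,
    int64_t y_offset, int64_t full_y_size, int64_t batch_size,
    int64_t num_layers, int64_t layer_idx, float scale);
template void bgmv_kernel<4096, 16>(
    nv_half *__restrict__ Y, const nv_half *__restrict__ X,
    const nv_half *__restrict__ W, const int64_t *__restrict__ indicies,
    int64_t y_offset, int64_t full_y_size, int64_t batch_size,
    int64_t num_layers, int64_t layer_idx, float scale);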
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)

csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)

csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)

csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)

csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)

csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)

csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)

csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)

csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)

csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)

csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)

csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
csrc/punica/bgmv/bgmv_impl.cuh (new file, mode 100644)

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>

#include "vec_dtypes.cuh"

namespace cg = cooperative_groups;

// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
          typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t j = blockIdx.x;
  constexpr size_t num_pipeline_stages = 2;
  constexpr size_t tile_size = tx * ty * vec_size;
  __shared__ W_T W_shared[num_pipeline_stages * tile_size];
  __shared__ in_T X_shared[num_pipeline_stages * tile_size];
  __shared__ float y_warpwise[ty];

  size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  auto pipe = cuda::make_pipeline();

  // pipeline load W/X and compute WX;
  pipe.producer_acquire();
  cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     W + (idx * feat_out + j) * feat_in +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
  cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     X + (batch_idx * feat_in) +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
  pipe.producer_commit();
  size_t copy_idx, compute_idx;
  float y = 0.f;
  vec_t<in_T, vec_size> x_vec;
  vec_t<W_T, vec_size> w_vec;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
       ++tile_idx) {
    copy_idx = tile_idx % num_pipeline_stages;
    // pipeline stage: async copy W fragment
    pipe.producer_acquire();
    if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
      cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         W + (idx * feat_out + j) * feat_in +
                             tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
      cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         X + (batch_idx * feat_in) + tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
    }
    pipe.producer_commit();

    compute_idx = (tile_idx - 1) % num_pipeline_stages;
    // pipeline stage: compute WX
    pipe.consumer_wait();
    block.sync();
    x_vec.load(X_shared + X_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    w_vec.load(W_shared + W_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += float(w_vec[i]) * float(x_vec[i]) * scale;
    }
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    y_warpwise[threadIdx.y] = sum;
    block.sync();
#pragma unroll
    for (size_t i = 0; i < ty; ++i) {
      y += y_warpwise[i];
    }

    block.sync();
    pipe.consumer_release();
  }

  compute_idx = (tile_idx - 1) % num_pipeline_stages;
  // final pipeline stage
  pipe.consumer_wait();
  block.sync();
  x_vec.load(X_shared + X_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  w_vec.load(W_shared + W_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  y_warpwise[threadIdx.y] =
      ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
          ? sum
          : 0.f;
  block.sync();
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y += y_warpwise[i];
  }

  block.sync();
  pipe.consumer_release();

  // write Y;
  if (block.thread_rank() == 0) {
    Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
  }
}

// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
          typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;

  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t tile_idx = blockIdx.x;

  // load X;
  vec_t<in_T, vec_size> x_vec;
  x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);

  // load W;
  vec_t<W_T, vec_size> w_vec;
  w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
             block.thread_rank() * vec_size);

  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }

  cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += g.shfl_down(sum, offset);
  }
  sum = g.shfl(sum, 0);

  if (threadIdx.x == 0) {
    Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
      threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
  }
}

template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale) {
  constexpr size_t vec_size = 8;
  constexpr int tz = 4;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if constexpr (feat_in < feat_out) {
    static_assert(feat_in % vec_size == 0);
    constexpr int tx = feat_in / vec_size;

    static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
                  (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
                  (8 % tx == 0 && feat_out % (8 / tx * tz) == 0));

    if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
      constexpr int ty = 32 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
      constexpr int ty = 16 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else {
      constexpr int ty = 8 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  } else {
    static_assert(feat_in % (vec_size * 32) == 0 ||
                  feat_in % (vec_size * 16) == 0 ||
                  feat_in % (vec_size * 8) == 0);

    if constexpr (feat_in % (vec_size * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
                         vec_size * sizeof(W_T), tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
      constexpr int tx = 16;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  }
}

#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
template void bgmv_kernel<feat_in, feat_out>( \
out_T * __restrict__ Y, const in_T *__restrict__ X, \
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)
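
A rough host-side usage sketch (assumed, not part of this commit): punica_ops.cc is what normally dispatches to these explicit instantiations, but a direct call for one shape pair would look roughly like the helper below. The helper name, buffer names, shapes, and the example sizes are illustrative assumptions, following the indexing used in the kernels above.

// Hypothetical sketch: calling the shrink-direction instantiation
// bgmv_kernel<4096, 16> with fp16 tensors (covered by bgmv_fp16_fp16_fp16.cu).
// Assumed layouts, per the kernel indexing above:
//   d_X:       [batch_size, 4096]                  input activations
//   d_W:       [num_loras * num_layers, 16, 4096]  per-layer LoRA weights
//   d_Y:       [batch_size, full_y_size]           output accumulator
//   d_indices: [batch_size]                        LoRA index per request
//              (a negative resolved index makes the kernel return early)
void run_shrink_example(nv_half *d_Y, const nv_half *d_X, const nv_half *d_W,
                        const int64_t *d_indices) {
  bgmv_kernel<4096, 16>(d_Y, d_X, d_W, d_indices,
                        /*y_offset=*/0, /*full_y_size=*/16,
                        /*batch_size=*/8, /*num_layers=*/32,
                        /*layer_idx=*/0, /*scale=*/1.0f);
}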
csrc/punica/bgmv/generator.py (new file, mode 100644)

DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
    "fp16": "nv_half",
    "bf16": "nv_bfloat16",
    "fp32": "float",
}

TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()

for input_dtype in DTYPES:
    for output_dtype in DTYPES:
        for weight_dtype in DTYPES:
            if weight_dtype == "fp32":
                # FP32 weights are not supported.
                continue
            kernel_definition = TEMPLATE.format(
                input_dtype=DTYPE_MAP[input_dtype],
                output_dtype=DTYPE_MAP[output_dtype],
                weight_dtype=DTYPE_MAP[weight_dtype])
            filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
            with open(filename, "w") as f:
                f.write(kernel_definition)
csrc/punica/bgmv/vec_dtypes.cuh (new file, mode 100644, +1324 lines) — diff collapsed, not shown here.
csrc/punica/punica_ops.cc (new file, mode 100644, +563 lines) — diff collapsed, not shown here.
csrc/pybind.cpp (modified, +37 -4; hunks shown as they read after the change)

@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    "silu_and_mul",
    &silu_and_mul,
    "Activation function used in SwiGLU.");
  ops.def(
    "gelu_and_mul",
    &gelu_and_mul,
    "Activation function used in GeGLU.");
  ops.def(
    "gelu_new",
    &gelu_new,

@@ -48,13 +52,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    &rotary_embedding,
    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");

  // Quantization ops
#ifndef USE_ROCM
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
  ops.def("moe_align_block_size",
          &moe_align_block_size,
          "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");

  // Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");

@@ -71,9 +82,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    &reshape_and_cache,
    "Reshape the key and value tensors and cache them");
  cache_ops.def(
    "convert_fp8_e5m2",
    &convert_fp8_e5m2,
    "Convert the key and value cache to fp8_e5m2 data type");
(this binding replaces the previous gather_cached_kv binding, "Gather key and value from the cache into contiguous QKV tensors")

  // Cuda utils
  pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");

@@ -81,4 +92,26 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    "get_device_attribute",
    &get_device_attribute,
    "Gets the specified device attribute.");
  cuda_utils.def(
    "get_max_shared_memory_per_block_device_attribute",
    &get_max_shared_memory_per_block_device_attribute,
    "Gets the maximum shared memory per block device attribute.");

#ifndef USE_ROCM
  // Custom all-reduce kernels
  pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
  custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
  custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
  custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
  custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
  custom_ar.def("dispose", &dispose, "dispose");
  custom_ar.def("meta_size", &meta_size, "meta_size");
  custom_ar.def("register_buffer", &register_buffer, "register_buffer");
  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
                "get_graph_buffer_ipc_meta");
  custom_ar.def("register_graph_buffers", &register_graph_buffers,
                "register_graph_buffers");
#endif
}