Project: norm/vllm

Commit e00b0a19, authored Mar 23, 2024 by zhuwenwen
Commit message: merge v0.3.3
Parents: ead94d93, 3f1166ab
Changes: 239 files
Showing 20 changed files with 2362 additions and 4 deletions (+2362, -4):

  csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu   +4     -0
  csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu   +4     -0
  csrc/punica/bgmv/bgmv_config.h            +61    -0
  csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu   +4     -0
  csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu   +4     -0
  csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu   +4     -0
  csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu   +4     -0
  csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu   +4     -0
  csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu   +4     -0
  csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu   +4     -0
  csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu   +4     -0
  csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu   +4     -0
  csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu   +4     -0
  csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu   +4     -0
  csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu   +4     -0
  csrc/punica/bgmv/bgmv_impl.cuh            +294   -0
  csrc/punica/bgmv/generator.py             +27    -0
  csrc/punica/bgmv/vec_dtypes.cuh           +1324  -0
  csrc/punica/punica_ops.cc                 +563   -0
  csrc/pybind.cpp                           +37    -4
csrc/punica/bgmv/bgmv_bf16_fp32_bf16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_bfloat16)
csrc/punica/bgmv/bgmv_bf16_fp32_fp16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_bfloat16, float, nv_half)
csrc/punica/bgmv/bgmv_config.h (new file, mode 100644)

#pragma once

template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale);
// clang-format off
#define FOR_BGMV_WIDE(f, in_T, out_T, W_T, narrow) \
f(in_T, out_T, W_T, narrow, 128) \
f(in_T, out_T, W_T, narrow, 256) \
f(in_T, out_T, W_T, narrow, 512) \
f(in_T, out_T, W_T, narrow, 1024) \
f(in_T, out_T, W_T, narrow, 1280) \
f(in_T, out_T, W_T, narrow, 1728) \
f(in_T, out_T, W_T, narrow, 1792) \
f(in_T, out_T, W_T, narrow, 2048) \
f(in_T, out_T, W_T, narrow, 2560) \
f(in_T, out_T, W_T, narrow, 2752) \
f(in_T, out_T, W_T, narrow, 3072) \
f(in_T, out_T, W_T, narrow, 3456) \
f(in_T, out_T, W_T, narrow, 3584) \
f(in_T, out_T, W_T, narrow, 4096) \
f(in_T, out_T, W_T, narrow, 5120) \
f(in_T, out_T, W_T, narrow, 5504) \
f(in_T, out_T, W_T, narrow, 5632) \
f(in_T, out_T, W_T, narrow, 6144) \
f(in_T, out_T, W_T, narrow, 6912) \
f(in_T, out_T, W_T, narrow, 7168) \
f(in_T, out_T, W_T, narrow, 8192) \
f(in_T, out_T, W_T, narrow, 9216) \
f(in_T, out_T, W_T, narrow, 10240) \
f(in_T, out_T, W_T, narrow, 11008) \
f(in_T, out_T, W_T, narrow, 12288) \
f(in_T, out_T, W_T, narrow, 13824) \
f(in_T, out_T, W_T, narrow, 14336) \
f(in_T, out_T, W_T, narrow, 16384) \
f(in_T, out_T, W_T, narrow, 20480) \
f(in_T, out_T, W_T, narrow, 24576) \
f(in_T, out_T, W_T, narrow, 28672) \
f(in_T, out_T, W_T, narrow, 32000) \
f(in_T, out_T, W_T, narrow, 32256) \
f(in_T, out_T, W_T, narrow, 32512) \
f(in_T, out_T, W_T, narrow, 32768) \
f(in_T, out_T, W_T, narrow, 33024) \
f(in_T, out_T, W_T, narrow, 36864) \
f(in_T, out_T, W_T, narrow, 49152) \
// Keep above in sync with vllm/lora/layers::SamplerWithLoRA
// Keep this in sync with vllm/config::LoRAConfig
#define FOR_BGMV_WIDE_NARROW(f, in_T, out_T, W_T) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 8) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 16) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 32) \
FOR_BGMV_WIDE(f, in_T, out_T, W_T, 64)
// clang-format on
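
For orientation (not part of the diff): applied to a callback f, FOR_BGMV_WIDE_NARROW simply enumerates every (narrow, wide) shape pair, so a single use such as FOR_BGMV_WIDE_NARROW(f, nv_half, nv_half, nv_half) expands roughly as sketched below.

// Expansion sketch (illustration only, not in the commit):
//   FOR_BGMV_WIDE_NARROW(f, nv_half, nv_half, nv_half)
//     -> f(nv_half, nv_half, nv_half, 8, 128)
//        f(nv_half, nv_half, nv_half, 8, 256)
//        ...
//        f(nv_half, nv_half, nv_half, 8, 49152)
//        f(nv_half, nv_half, nv_half, 16, 128)
//        ...
//        f(nv_half, nv_half, nv_half, 64, 49152)
// i.e. each narrow rank in {8, 16, 32, 64} paired with every wide hidden size
// listed in FOR_BGMV_WIDE above.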
csrc/punica/bgmv/bgmv_fp16_bf16_bf16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_bfloat16)
csrc/punica/bgmv/bgmv_fp16_bf16_fp16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)
csrc/punica/bgmv/bgmv_fp16_fp16_bf16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_bfloat16)
csrc/punica/bgmv/bgmv_fp16_fp16_fp16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_half, nv_half)
csrc/punica/bgmv/bgmv_fp16_fp32_bf16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_bfloat16)
csrc/punica/bgmv/bgmv_fp16_fp32_fp16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, float, nv_half)
csrc/punica/bgmv/bgmv_fp32_bf16_bf16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_bfloat16)
csrc/punica/bgmv/bgmv_fp32_bf16_fp16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_bfloat16, nv_half)
csrc/punica/bgmv/bgmv_fp32_fp16_bf16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_bfloat16)
csrc/punica/bgmv/bgmv_fp32_fp16_fp16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, nv_half, nv_half)
csrc/punica/bgmv/bgmv_fp32_fp32_bf16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_bfloat16)
csrc/punica/bgmv/bgmv_fp32_fp32_fp16.cu (new file, mode 100644)

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, float, float, nv_half)
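
Taken together, the fourteen generated translation units above follow one naming convention; a sketch of the mapping, derived only from the files in this diff:

// bgmv_<input>_<output>_<weight>.cu
//   -> FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, in_T, out_T, W_T)
// with fp16 -> nv_half, bf16 -> nv_bfloat16, fp32 -> float.
// e.g. bgmv_fp32_fp16_bf16.cu instantiates in_T = float, out_T = nv_half,
// W_T = nv_bfloat16; note that no file uses fp32 weights.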
csrc/punica/bgmv/bgmv_impl.cuh (new file, mode 100644)

#pragma once

#include <ATen/cuda/CUDAContext.h>
#include <cooperative_groups.h>
#include <cuda/pipeline>
#include <cuda_runtime.h>
#include <iostream>
#include <stdio.h>

#include "vec_dtypes.cuh"

namespace cg = cooperative_groups;
// nthrs = (32, 4)
template <int feat_in, int feat_out, size_t vec_size, size_t X_copy_size,
          size_t W_copy_size, int tx, int ty, int tz, typename in_T,
          typename out_T, typename W_T>
__global__ void
bgmv_shrink_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t j = blockIdx.x;
  constexpr size_t num_pipeline_stages = 2;
  constexpr size_t tile_size = tx * ty * vec_size;
  __shared__ W_T W_shared[num_pipeline_stages * tile_size];
  __shared__ in_T X_shared[num_pipeline_stages * tile_size];
  __shared__ float y_warpwise[ty];

  size_t W_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  size_t X_shared_offset[num_pipeline_stages] = {0U, 1U * tile_size};
  auto pipe = cuda::make_pipeline();

  // pipeline load W/X and compute WX;
  pipe.producer_acquire();
  cuda::memcpy_async(W_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     W + (idx * feat_out + j) * feat_in +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
  cuda::memcpy_async(X_shared + (threadIdx.y * tx + threadIdx.x) * vec_size,
                     X + (batch_idx * feat_in) +
                         (threadIdx.y * tx + threadIdx.x) * vec_size,
                     cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
  pipe.producer_commit();
  size_t copy_idx, compute_idx;
  float y = 0.f;
  vec_t<in_T, vec_size> x_vec;
  vec_t<W_T, vec_size> w_vec;
  size_t tile_idx;

#pragma unroll
  for (tile_idx = 1; tile_idx < (feat_in + tile_size - 1) / tile_size;
       ++tile_idx) {
    copy_idx = tile_idx % num_pipeline_stages;
    // pipeline stage: async copy W fragment
    pipe.producer_acquire();
    if (tile_idx * tile_size + threadIdx.y * tx * vec_size < feat_in) {
      cuda::memcpy_async(W_shared + W_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         W + (idx * feat_out + j) * feat_in +
                             tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<W_copy_size>(W_copy_size), pipe);
      cuda::memcpy_async(X_shared + X_shared_offset[copy_idx] +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         X + (batch_idx * feat_in) + tile_idx * tile_size +
                             (threadIdx.y * tx + threadIdx.x) * vec_size,
                         cuda::aligned_size_t<X_copy_size>(X_copy_size), pipe);
    }
    pipe.producer_commit();

    compute_idx = (tile_idx - 1) % num_pipeline_stages;
    // pipeline stage: compute WX
    pipe.consumer_wait();
    block.sync();
    x_vec.load(X_shared + X_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    w_vec.load(W_shared + W_shared_offset[compute_idx] +
               (threadIdx.y * tx + threadIdx.x) * vec_size);
    float sum = 0.f;
#pragma unroll
    for (size_t i = 0; i < vec_size; ++i) {
      sum += float(w_vec[i]) * float(x_vec[i]) * scale;
    }
#pragma unroll
    for (size_t offset = tx / 2; offset > 0; offset /= 2) {
      sum += __shfl_down_sync(0xffffffff, sum, offset);
    }
    y_warpwise[threadIdx.y] = sum;
    block.sync();
#pragma unroll
    for (size_t i = 0; i < ty; ++i) {
      y += y_warpwise[i];
    }

    block.sync();
    pipe.consumer_release();
  }

  compute_idx = (tile_idx - 1) % num_pipeline_stages;
  // final pipeline stage
  pipe.consumer_wait();
  block.sync();
  x_vec.load(X_shared + X_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  w_vec.load(W_shared + W_shared_offset[compute_idx] +
             (threadIdx.y * tx + threadIdx.x) * vec_size);
  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += __shfl_down_sync(0xffffffff, sum, offset);
  }
  y_warpwise[threadIdx.y] =
      ((tile_idx - 1) * tile_size + threadIdx.y * tx * vec_size < feat_in)
          ? sum
          : 0.f;
  block.sync();
#pragma unroll
  for (size_t i = 0; i < ty; ++i) {
    y += y_warpwise[i];
  }

  block.sync();
  pipe.consumer_release();

  // write Y;
  if (block.thread_rank() == 0) {
    Y[batch_idx * full_y_size + y_offset + j] += static_cast<out_T>(y);
  }
}
// nthrs = (2, 16, 4)
template <int feat_in, int feat_out, size_t vec_size, int tx, int ty, int tz,
          typename in_T, typename out_T, typename W_T>
__global__ void
bgmv_expand_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                   const W_T *__restrict__ W,
                   const int64_t *__restrict__ indicies, int64_t y_offset,
                   int64_t full_y_size, int64_t num_layers, int64_t layer_idx,
                   float scale) {
  size_t batch_idx = blockIdx.y;
  int64_t idx = indicies[batch_idx] * num_layers + layer_idx;
  if (idx < 0) {
    return;
  }

  auto block = cg::this_thread_block();
  size_t tile_idx = blockIdx.x;

  // load X;
  vec_t<in_T, vec_size> x_vec;
  x_vec.load(X + batch_idx * feat_in + threadIdx.x * vec_size);

  // load W;
  vec_t<W_T, vec_size> w_vec;
  w_vec.load(W + (idx * feat_out + tile_idx * tz * ty) * feat_in +
             block.thread_rank() * vec_size);

  float sum = 0.f;
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    sum += float(w_vec[i]) * float(x_vec[i]) * scale;
  }

  cg::thread_block_tile g = cg::tiled_partition<tx>(block);
#pragma unroll
  for (size_t offset = tx / 2; offset > 0; offset /= 2) {
    sum += g.shfl_down(sum, offset);
  }
  sum = g.shfl(sum, 0);

  if (threadIdx.x == 0) {
    Y[batch_idx * full_y_size + y_offset + tile_idx * (tz * ty) +
      threadIdx.z * ty + threadIdx.y] += static_cast<out_T>(sum);
  }
}
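
As an aside (not part of the commit), both kernels above finish with the same warp-level tree reduction; a minimal standalone sketch of that pattern, assuming the cooperative-groups include and the cg alias from the top of this file:

// Illustration only: warp-wide sum via shuffle, as used by the two kernels
// above (here fixed to a full 32-thread tile).
__device__ float example_warp_sum(float val) {
  auto g = cg::tiled_partition<32>(cg::this_thread_block());
  for (int offset = 16; offset > 0; offset /= 2) {
    val += g.shfl_down(val, offset);  // lane i adds the value held by lane i+offset
  }
  return g.shfl(val, 0);  // broadcast the final sum from lane 0 to all lanes
}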
template <int feat_in, int feat_out, typename in_T, typename out_T,
          typename W_T>
void bgmv_kernel(out_T *__restrict__ Y, const in_T *__restrict__ X,
                 const W_T *__restrict__ W,
                 const int64_t *__restrict__ indicies, int64_t y_offset,
                 int64_t full_y_size, int64_t batch_size, int64_t num_layers,
                 int64_t layer_idx, float scale) {
  constexpr size_t vec_size = 8;
  constexpr int tz = 4;
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();

  if constexpr (feat_in < feat_out) {
    static_assert(feat_in % vec_size == 0);
    constexpr int tx = feat_in / vec_size;

    static_assert((32 % tx == 0 && feat_out % (32 / tx * tz) == 0) ||
                  (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) ||
                  (8 % tx == 0 && feat_out % (8 / tx * tz) == 0));

    if constexpr (32 % tx == 0 && feat_out % (32 / tx * tz) == 0) {
      constexpr int ty = 32 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if (16 % tx == 0 && feat_out % (16 / tx * tz) == 0) {
      constexpr int ty = 16 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else {
      constexpr int ty = 8 / tx;
      dim3 nblks(feat_out / (ty * tz), batch_size);
      dim3 nthrs(tx, ty, tz);

      bgmv_expand_kernel<feat_in, feat_out, vec_size, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  } else {
    static_assert(feat_in % (vec_size * 32) == 0 ||
                  feat_in % (vec_size * 16) == 0 ||
                  feat_in % (vec_size * 8) == 0);

    if constexpr (feat_in % (vec_size * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size, vec_size * sizeof(in_T),
                         vec_size * sizeof(W_T), tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 32) == 0) {
      constexpr int tx = 32;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    } else if constexpr (feat_in % (vec_size / 2 * 16) == 0) {
      constexpr int tx = 16;
      constexpr int ty = 4;

      dim3 nblks(feat_out, batch_size);
      dim3 nthrs(tx, ty);

      bgmv_shrink_kernel<feat_in, feat_out, vec_size / 2,
                         vec_size * sizeof(in_T) / 2,
                         vec_size * sizeof(W_T) / 2, tx, ty, tz>
          <<<nblks, nthrs, 0, stream>>>(Y, X, W, indicies, y_offset,
                                        full_y_size, num_layers, layer_idx,
                                        scale);
    }
  }
}
#define INST_BGMV(feat_in, feat_out, in_T, out_T, W_T) \
template void bgmv_kernel<feat_in, feat_out>( \
out_T * __restrict__ Y, const in_T *__restrict__ X, \
const W_T *__restrict__ W, const int64_t *__restrict__ indicies, \
int64_t y_offset, int64_t full_y_size, int64_t batch_size, \
int64_t num_layers, int64_t layer_idx, float scale);
#define INST_BGMV_TWOSIDE(in_T, out_T, W_T, narrow, wide) \
INST_BGMV(narrow, wide, in_T, out_T, W_T) \
INST_BGMV(wide, narrow, in_T, out_T, W_T)
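
A minimal host-side sketch of how this dispatcher is meant to be driven (illustration only; the function name, shapes, and buffers below are hypothetical, and the real call sites live in punica_ops.cc further down in this diff):

#include <cuda_fp16.h>
#include "bgmv_config.h"

// Hypothetical example: feat_in (4096) >= feat_out (16), so bgmv_kernel takes
// the shrink path; bgmv_kernel<16, 4096> would take the expand path instead.
// Both shape pairs are covered by the INST_BGMV_TWOSIDE instantiations above.
void example_launch_shrink(nv_half *Y, const nv_half *X, const nv_half *W,
                           const int64_t *lora_indices, int64_t batch_size,
                           int64_t num_layers, int64_t layer_idx) {
  bgmv_kernel<4096, 16>(Y, X, W, lora_indices, /*y_offset=*/0,
                        /*full_y_size=*/16, batch_size, num_layers, layer_idx,
                        /*scale=*/1.0f);
}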
csrc/punica/bgmv/generator.py (new file, mode 100644)

DTYPES = ["fp16", "bf16", "fp32"]
DTYPE_MAP = {
    "fp16": "nv_half",
    "bf16": "nv_bfloat16",
    "fp32": "float",
}

TEMPLATE = """
#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, {input_dtype}, {output_dtype}, {weight_dtype})
""".lstrip()

for input_dtype in DTYPES:
    for output_dtype in DTYPES:
        for weight_dtype in DTYPES:
            if weight_dtype == "fp32":
                # FP32 weights are not supported.
                continue
            kernel_definition = TEMPLATE.format(
                input_dtype=DTYPE_MAP[input_dtype],
                output_dtype=DTYPE_MAP[output_dtype],
                weight_dtype=DTYPE_MAP[weight_dtype])
            filename = f"bgmv_{input_dtype}_{output_dtype}_{weight_dtype}.cu"
            with open(filename, "w") as f:
                f.write(kernel_definition)
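
For reference, the three nested loops emit 3 x 3 x 2 = 18 translation units (fp32 weights are skipped); each generated file is just the template with the C type names substituted, for example bgmv_fp16_bf16_fp16.cu comes out as the file already shown earlier in this diff:

#include "bgmv_config.h"
#include "bgmv_impl.cuh"

FOR_BGMV_WIDE_NARROW(INST_BGMV_TWOSIDE, nv_half, nv_bfloat16, nv_half)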
csrc/punica/bgmv/vec_dtypes.cuh (new file, mode 100644)

#ifndef VEC_DTYPES_CUH_
#define VEC_DTYPES_CUH_

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#ifdef FLASHINFER_USE_FP8
#include <cuda_fp8.h>
#endif
#include <cuda_runtime.h>

#include <type_traits>

#define FLASHINFER_INLINE \
  inline __attribute__((always_inline)) __device__ __host__

template <typename float_t, size_t vec_size>
struct vec_t {
  FLASHINFER_INLINE float_t &operator[](size_t i);
  FLASHINFER_INLINE const float_t &operator[](size_t i) const;
  FLASHINFER_INLINE void fill(float_t val);
  FLASHINFER_INLINE void load(const float_t *ptr);
  FLASHINFER_INLINE void store(float_t *ptr) const;
  template <typename T>
  FLASHINFER_INLINE void cast_from(const vec_t<T, vec_size> &src);
  template <typename T>
  FLASHINFER_INLINE void cast_load(const T *ptr);
  template <typename T>
  FLASHINFER_INLINE void cast_store(T *ptr) const;
  FLASHINFER_INLINE static void memcpy(float_t *dst, const float_t *src);
};

template <typename src_float_t, typename tgt_float_t, size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<src_float_t, vec_size> &src,
                                      vec_t<tgt_float_t, vec_size> &dst) {
#pragma unroll
  for (size_t i = 0; i < vec_size; ++i) {
    dst[i] = tgt_float_t(src[i]);
  }
}

template <typename src_float_t, typename tgt_float_t, size_t vec_size>
FLASHINFER_INLINE void cast_load_impl(const src_float_t *src_ptr,
                                      vec_t<tgt_float_t, vec_size> &dst) {
  if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
    dst.load(src_ptr);
  } else {
    vec_t<src_float_t, vec_size> tmp;
    tmp.load(src_ptr);
    dst.cast_from(tmp);
  }
}

template <typename src_float_t, typename tgt_float_t, size_t vec_size>
FLASHINFER_INLINE void cast_store_impl(const vec_t<src_float_t, vec_size> &src,
                                       tgt_float_t *dst_ptr) {
  if constexpr (std::is_same<src_float_t, tgt_float_t>::value) {
    src.store(dst_ptr);
  } else {
    vec_t<tgt_float_t, vec_size> tmp;
    tmp.cast_from(src);
    tmp.store(dst_ptr);
  }
}
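
A minimal device-code sketch of this vec_t interface (illustration only, not part of the file): load eight halfs, convert to float on load, and touch individual lanes. It assumes 16-byte-aligned pointers, which the per-type specializations below rely on for their vectorized loads.

// Illustration only -- not part of vec_dtypes.cuh.
__device__ void example_vec_cast_load(const half *src, float *dst) {
  vec_t<float, 8> v;
  v.cast_load(src);       // vectorized load of 8 halfs, each lane cast to float
#pragma unroll
  for (size_t i = 0; i < 8; ++i) {
    dst[i] = v[i] * 2.f;  // per-lane access through operator[]
  }
}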
#ifdef FLASHINFER_USE_FP8
/******************* vec_t<__nv_fp8_e4m3> *******************/
// __nv_fp8_e4m3 x 1
template
<
>
struct
vec_t
<
__nv_fp8_e4m3
,
1
>
{
__nv_fp8_e4m3
data
;
FLASHINFER_INLINE
__nv_fp8_e4m3
&
operator
[](
size_t
i
)
{
return
((
__nv_fp8_e4m3
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
__nv_fp8_e4m3
&
operator
[](
size_t
i
)
const
{
return
((
const
__nv_fp8_e4m3
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
__nv_fp8_e4m3
val
);
FLASHINFER_INLINE
void
load
(
const
__nv_fp8_e4m3
*
ptr
);
FLASHINFER_INLINE
void
store
(
__nv_fp8_e4m3
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
1
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
__nv_fp8_e4m3
*
dst
,
const
__nv_fp8_e4m3
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
1
>::
fill
(
__nv_fp8_e4m3
val
)
{
data
=
val
;
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
1
>::
load
(
const
__nv_fp8_e4m3
*
ptr
)
{
data
=
*
ptr
;
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
1
>::
store
(
__nv_fp8_e4m3
*
ptr
)
const
{
*
ptr
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
1
>::
memcpy
(
__nv_fp8_e4m3
*
dst
,
const
__nv_fp8_e4m3
*
src
)
{
*
dst
=
*
src
;
}
// __nv_fp8_e4m3 x 2
template
<
>
struct
vec_t
<
__nv_fp8_e4m3
,
2
>
{
__nv_fp8x2_e4m3
data
;
FLASHINFER_INLINE
__nv_fp8_e4m3
&
operator
[](
size_t
i
)
{
return
((
__nv_fp8_e4m3
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
__nv_fp8_e4m3
&
operator
[](
size_t
i
)
const
{
return
((
const
__nv_fp8_e4m3
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
__nv_fp8_e4m3
val
);
FLASHINFER_INLINE
void
load
(
const
__nv_fp8_e4m3
*
ptr
);
FLASHINFER_INLINE
void
store
(
__nv_fp8_e4m3
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
2
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
__nv_fp8_e4m3
*
dst
,
const
__nv_fp8_e4m3
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
2
>::
fill
(
__nv_fp8_e4m3
val
)
{
data
.
__x
=
(
__nv_fp8x2_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x2_storage_t
(
val
.
__x
);
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
2
>::
load
(
const
__nv_fp8_e4m3
*
ptr
)
{
data
=
*
((
__nv_fp8x2_e4m3
*
)
ptr
);
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
2
>::
store
(
__nv_fp8_e4m3
*
ptr
)
const
{
*
((
__nv_fp8x2_e4m3
*
)
ptr
)
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
2
>::
memcpy
(
__nv_fp8_e4m3
*
dst
,
const
__nv_fp8_e4m3
*
src
)
{
*
((
__nv_fp8x2_e4m3
*
)
dst
)
=
*
((
__nv_fp8x2_e4m3
*
)
src
);
}
// __nv_fp8_e4m3 x 4
template
<
>
struct
vec_t
<
__nv_fp8_e4m3
,
4
>
{
__nv_fp8x4_e4m3
data
;
FLASHINFER_INLINE
__nv_fp8_e4m3
&
operator
[](
size_t
i
)
{
return
((
__nv_fp8_e4m3
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
__nv_fp8_e4m3
&
operator
[](
size_t
i
)
const
{
return
((
const
__nv_fp8_e4m3
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
__nv_fp8_e4m3
val
);
FLASHINFER_INLINE
void
load
(
const
__nv_fp8_e4m3
*
ptr
);
FLASHINFER_INLINE
void
store
(
__nv_fp8_e4m3
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
4
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
__nv_fp8_e4m3
*
dst
,
const
__nv_fp8_e4m3
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
4
>::
fill
(
__nv_fp8_e4m3
val
)
{
data
.
__x
=
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
24
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
16
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x4_storage_t
(
val
.
__x
);
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
4
>::
load
(
const
__nv_fp8_e4m3
*
ptr
)
{
data
=
*
((
__nv_fp8x4_e4m3
*
)
ptr
);
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
4
>::
store
(
__nv_fp8_e4m3
*
ptr
)
const
{
*
((
__nv_fp8x4_e4m3
*
)
ptr
)
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
4
>::
memcpy
(
__nv_fp8_e4m3
*
dst
,
const
__nv_fp8_e4m3
*
src
)
{
*
((
__nv_fp8x4_e4m3
*
)
dst
)
=
*
((
__nv_fp8x4_e4m3
*
)
src
);
}
// __nv_fp8_e4m3 x 8
template
<
>
struct
vec_t
<
__nv_fp8_e4m3
,
8
>
{
uint2
data
;
FLASHINFER_INLINE
__nv_fp8_e4m3
&
operator
[](
size_t
i
)
{
return
((
__nv_fp8_e4m3
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
__nv_fp8_e4m3
&
operator
[](
size_t
i
)
const
{
return
((
const
__nv_fp8_e4m3
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
__nv_fp8_e4m3
val
);
FLASHINFER_INLINE
void
load
(
const
__nv_fp8_e4m3
*
ptr
);
FLASHINFER_INLINE
void
store
(
__nv_fp8_e4m3
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
8
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
__nv_fp8_e4m3
*
dst
,
const
__nv_fp8_e4m3
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
8
>::
fill
(
__nv_fp8_e4m3
val
)
{
((
__nv_fp8x4_e4m3
*
)(
&
data
.
x
))
->
__x
=
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
24
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
16
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x4_storage_t
(
val
.
__x
);
((
__nv_fp8x4_e4m3
*
)(
&
data
.
y
))
->
__x
=
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
24
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
16
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x4_storage_t
(
val
.
__x
);
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
8
>::
load
(
const
__nv_fp8_e4m3
*
ptr
)
{
data
=
*
((
uint2
*
)
ptr
);
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
8
>::
store
(
__nv_fp8_e4m3
*
ptr
)
const
{
*
((
uint2
*
)
ptr
)
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e4m3
,
8
>::
memcpy
(
__nv_fp8_e4m3
*
dst
,
const
__nv_fp8_e4m3
*
src
)
{
*
((
__nv_fp8_e4m3
*
)
dst
)
=
*
((
__nv_fp8_e4m3
*
)
src
);
}
// __nv_fp8_e4m3 x 16 or more
template
<
size_t
vec_size
>
struct
vec_t
<
__nv_fp8_e4m3
,
vec_size
>
{
uint4
data
[
vec_size
/
16
];
FLASHINFER_INLINE
__nv_fp8_e4m3
&
operator
[](
size_t
i
)
{
return
((
__nv_fp8_e4m3
*
)
data
)[
i
];
}
FLASHINFER_INLINE
const
__nv_fp8_e4m3
&
operator
[](
size_t
i
)
const
{
return
((
const
__nv_fp8_e4m3
*
)
data
)[
i
];
}
FLASHINFER_INLINE
void
fill
(
__nv_fp8_e4m3
val
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
16
;
++
i
)
{
((
__nv_fp8x4_e4m3
*
)(
&
(
data
[
i
].
x
)))
->
__x
=
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
24
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
16
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x4_storage_t
(
val
.
__x
);
((
__nv_fp8x4_e4m3
*
)(
&
(
data
[
i
].
y
)))
->
__x
=
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
24
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
16
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x4_storage_t
(
val
.
__x
);
((
__nv_fp8x4_e4m3
*
)(
&
(
data
[
i
].
z
)))
->
__x
=
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
24
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
16
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x4_storage_t
(
val
.
__x
);
((
__nv_fp8x4_e4m3
*
)(
&
(
data
[
i
].
w
)))
->
__x
=
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
24
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
16
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x4_storage_t
(
val
.
__x
);
}
}
FLASHINFER_INLINE
void
load
(
const
__nv_fp8_e4m3
*
ptr
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
16
;
++
i
)
{
data
[
i
]
=
((
uint4
*
)
ptr
)[
i
];
}
}
FLASHINFER_INLINE
void
store
(
__nv_fp8_e4m3
*
ptr
)
const
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
16
;
++
i
)
{
((
uint4
*
)
ptr
)[
i
]
=
data
[
i
];
}
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
vec_size
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
__nv_fp8_e4m3
*
dst
,
const
__nv_fp8_e4m3
*
src
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
16
;
++
i
)
{
((
uint4
*
)
dst
)[
i
]
=
((
uint4
*
)
src
)[
i
];
}
}
};
/******************* vec_t<__nv_fp8_e5m2> *******************/
// __nv_fp8_e5m2 x 1
template
<
>
struct
vec_t
<
__nv_fp8_e5m2
,
1
>
{
__nv_fp8_e5m2
data
;
FLASHINFER_INLINE
__nv_fp8_e5m2
&
operator
[](
size_t
i
)
{
return
((
__nv_fp8_e5m2
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
__nv_fp8_e5m2
&
operator
[](
size_t
i
)
const
{
return
((
const
__nv_fp8_e5m2
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
__nv_fp8_e5m2
val
);
FLASHINFER_INLINE
void
load
(
const
__nv_fp8_e5m2
*
ptr
);
FLASHINFER_INLINE
void
store
(
__nv_fp8_e5m2
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
1
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
__nv_fp8_e5m2
*
dst
,
const
__nv_fp8_e5m2
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
1
>::
fill
(
__nv_fp8_e5m2
val
)
{
data
=
val
;
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
1
>::
load
(
const
__nv_fp8_e5m2
*
ptr
)
{
data
=
*
ptr
;
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
1
>::
store
(
__nv_fp8_e5m2
*
ptr
)
const
{
*
ptr
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
1
>::
memcpy
(
__nv_fp8_e5m2
*
dst
,
const
__nv_fp8_e5m2
*
src
)
{
*
dst
=
*
src
;
}
// __nv_fp8_e5m2 x 2
template
<
>
struct
vec_t
<
__nv_fp8_e5m2
,
2
>
{
__nv_fp8x2_e5m2
data
;
FLASHINFER_INLINE
__nv_fp8_e5m2
&
operator
[](
size_t
i
)
{
return
((
__nv_fp8_e5m2
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
__nv_fp8_e5m2
&
operator
[](
size_t
i
)
const
{
return
((
const
__nv_fp8_e5m2
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
__nv_fp8_e5m2
val
);
FLASHINFER_INLINE
void
load
(
const
__nv_fp8_e5m2
*
ptr
);
FLASHINFER_INLINE
void
store
(
__nv_fp8_e5m2
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
2
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
__nv_fp8_e5m2
*
dst
,
const
__nv_fp8_e5m2
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
2
>::
fill
(
__nv_fp8_e5m2
val
)
{
data
.
__x
=
(
__nv_fp8x2_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x2_storage_t
(
val
.
__x
);
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
2
>::
load
(
const
__nv_fp8_e5m2
*
ptr
)
{
data
=
*
((
__nv_fp8x2_e5m2
*
)
ptr
);
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
2
>::
store
(
__nv_fp8_e5m2
*
ptr
)
const
{
*
((
__nv_fp8x2_e5m2
*
)
ptr
)
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
2
>::
memcpy
(
__nv_fp8_e5m2
*
dst
,
const
__nv_fp8_e5m2
*
src
)
{
*
((
__nv_fp8x2_e5m2
*
)
dst
)
=
*
((
__nv_fp8x2_e5m2
*
)
src
);
}
// __nv_fp8_e5m2 x 4
template
<
>
struct
vec_t
<
__nv_fp8_e5m2
,
4
>
{
__nv_fp8x4_e5m2
data
;
FLASHINFER_INLINE
__nv_fp8_e5m2
&
operator
[](
size_t
i
)
{
return
((
__nv_fp8_e5m2
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
__nv_fp8_e5m2
&
operator
[](
size_t
i
)
const
{
return
((
const
__nv_fp8_e5m2
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
__nv_fp8_e5m2
val
);
FLASHINFER_INLINE
void
load
(
const
__nv_fp8_e5m2
*
ptr
);
FLASHINFER_INLINE
void
store
(
__nv_fp8_e5m2
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
4
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
__nv_fp8_e5m2
*
dst
,
const
__nv_fp8_e5m2
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
4
>::
fill
(
__nv_fp8_e5m2
val
)
{
data
.
__x
=
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
24
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
16
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x4_storage_t
(
val
.
__x
);
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
4
>::
load
(
const
__nv_fp8_e5m2
*
ptr
)
{
data
=
*
((
__nv_fp8x4_e5m2
*
)
ptr
);
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
4
>::
store
(
__nv_fp8_e5m2
*
ptr
)
const
{
*
((
__nv_fp8x4_e5m2
*
)
ptr
)
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
4
>::
memcpy
(
__nv_fp8_e5m2
*
dst
,
const
__nv_fp8_e5m2
*
src
)
{
*
((
__nv_fp8x4_e5m2
*
)
dst
)
=
*
((
__nv_fp8x4_e5m2
*
)
src
);
}
// __nv_fp8_e5m2 x 8
template
<
>
struct
vec_t
<
__nv_fp8_e5m2
,
8
>
{
uint2
data
;
FLASHINFER_INLINE
__nv_fp8_e5m2
&
operator
[](
size_t
i
)
{
return
((
__nv_fp8_e5m2
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
__nv_fp8_e5m2
&
operator
[](
size_t
i
)
const
{
return
((
const
__nv_fp8_e5m2
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
__nv_fp8_e5m2
val
);
FLASHINFER_INLINE
void
load
(
const
__nv_fp8_e5m2
*
ptr
);
FLASHINFER_INLINE
void
store
(
__nv_fp8_e5m2
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
8
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
__nv_fp8_e5m2
*
dst
,
const
__nv_fp8_e5m2
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
8
>::
fill
(
__nv_fp8_e5m2
val
)
{
((
__nv_fp8x4_e5m2
*
)(
&
data
.
x
))
->
__x
=
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
24
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
16
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x4_storage_t
(
val
.
__x
);
((
__nv_fp8x4_e5m2
*
)(
&
data
.
y
))
->
__x
=
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
24
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
16
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x4_storage_t
(
val
.
__x
);
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
8
>::
load
(
const
__nv_fp8_e5m2
*
ptr
)
{
data
=
*
((
uint2
*
)
ptr
);
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
8
>::
store
(
__nv_fp8_e5m2
*
ptr
)
const
{
*
((
uint2
*
)
ptr
)
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
__nv_fp8_e5m2
,
8
>::
memcpy
(
__nv_fp8_e5m2
*
dst
,
const
__nv_fp8_e5m2
*
src
)
{
*
((
__nv_fp8_e5m2
*
)
dst
)
=
*
((
__nv_fp8_e5m2
*
)
src
);
}
// __nv_fp8_e5m2 x 16 or more
template
<
size_t
vec_size
>
struct
vec_t
<
__nv_fp8_e5m2
,
vec_size
>
{
uint4
data
[
vec_size
/
16
];
FLASHINFER_INLINE
__nv_fp8_e5m2
&
operator
[](
size_t
i
)
{
return
((
__nv_fp8_e5m2
*
)
data
)[
i
];
}
FLASHINFER_INLINE
const
__nv_fp8_e5m2
&
operator
[](
size_t
i
)
const
{
return
((
const
__nv_fp8_e5m2
*
)
data
)[
i
];
}
FLASHINFER_INLINE
void
fill
(
__nv_fp8_e5m2
val
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
16
;
++
i
)
{
((
__nv_fp8x4_e5m2
*
)(
&
(
data
[
i
].
x
)))
->
__x
=
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
24
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
16
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x4_storage_t
(
val
.
__x
);
((
__nv_fp8x4_e5m2
*
)(
&
(
data
[
i
].
y
)))
->
__x
=
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
24
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
16
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x4_storage_t
(
val
.
__x
);
((
__nv_fp8x4_e5m2
*
)(
&
(
data
[
i
].
z
)))
->
__x
=
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
24
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
16
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x4_storage_t
(
val
.
__x
);
((
__nv_fp8x4_e5m2
*
)(
&
(
data
[
i
].
w
)))
->
__x
=
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
24
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
16
)
|
(
__nv_fp8x4_storage_t
(
val
.
__x
)
<<
8
)
|
__nv_fp8x4_storage_t
(
val
.
__x
);
}
}
FLASHINFER_INLINE
void
load
(
const
__nv_fp8_e5m2
*
ptr
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
16
;
++
i
)
{
data
[
i
]
=
((
uint4
*
)
ptr
)[
i
];
}
}
FLASHINFER_INLINE
void
store
(
__nv_fp8_e5m2
*
ptr
)
const
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
16
;
++
i
)
{
((
uint4
*
)
ptr
)[
i
]
=
data
[
i
];
}
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
vec_size
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
__nv_fp8_e5m2
*
dst
,
const
__nv_fp8_e5m2
*
src
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
16
;
++
i
)
{
((
uint4
*
)
dst
)[
i
]
=
((
uint4
*
)
src
)[
i
];
}
}
};
#endif
/******************* vec_t<half> *******************/
// half x 1
template
<
>
struct
vec_t
<
half
,
1
>
{
half
data
;
FLASHINFER_INLINE
half
&
operator
[](
size_t
i
)
{
return
((
half
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
half
&
operator
[](
size_t
i
)
const
{
return
((
const
half
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
half
val
);
FLASHINFER_INLINE
void
load
(
const
half
*
ptr
);
FLASHINFER_INLINE
void
store
(
half
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
1
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
half
*
dst
,
const
half
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
half
,
1
>::
fill
(
half
val
)
{
data
=
val
;
}
FLASHINFER_INLINE
void
vec_t
<
half
,
1
>::
load
(
const
half
*
ptr
)
{
data
=
*
ptr
;
}
FLASHINFER_INLINE
void
vec_t
<
half
,
1
>::
store
(
half
*
ptr
)
const
{
*
ptr
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
half
,
1
>::
memcpy
(
half
*
dst
,
const
half
*
src
)
{
*
dst
=
*
src
;
}
// half x 2
template
<
>
struct
vec_t
<
half
,
2
>
{
half2
data
;
FLASHINFER_INLINE
half
&
operator
[](
size_t
i
)
{
return
((
half
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
half
&
operator
[](
size_t
i
)
const
{
return
((
const
half
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
half
val
);
FLASHINFER_INLINE
void
load
(
const
half
*
ptr
);
FLASHINFER_INLINE
void
store
(
half
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
2
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
half
*
dst
,
const
half
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
half
,
2
>::
fill
(
half
val
)
{
data
=
make_half2
(
val
,
val
);
}
FLASHINFER_INLINE
void
vec_t
<
half
,
2
>::
load
(
const
half
*
ptr
)
{
data
=
*
((
half2
*
)
ptr
);
}
FLASHINFER_INLINE
void
vec_t
<
half
,
2
>::
store
(
half
*
ptr
)
const
{
*
((
half2
*
)
ptr
)
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
half
,
2
>::
memcpy
(
half
*
dst
,
const
half
*
src
)
{
*
((
half2
*
)
dst
)
=
*
((
half2
*
)
src
);
}
// half x 4
template
<
>
struct
vec_t
<
half
,
4
>
{
uint2
data
;
FLASHINFER_INLINE
half
&
operator
[](
size_t
i
)
{
return
((
half
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
half
&
operator
[](
size_t
i
)
const
{
return
((
const
half
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
half
val
);
FLASHINFER_INLINE
void
load
(
const
half
*
ptr
);
FLASHINFER_INLINE
void
store
(
half
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
4
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
half
*
dst
,
const
half
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
half
,
4
>::
fill
(
half
val
)
{
*
(
half2
*
)(
&
data
.
x
)
=
make_half2
(
val
,
val
);
*
(
half2
*
)(
&
data
.
y
)
=
make_half2
(
val
,
val
);
}
FLASHINFER_INLINE
void
vec_t
<
half
,
4
>::
load
(
const
half
*
ptr
)
{
data
=
*
((
uint2
*
)
ptr
);
}
FLASHINFER_INLINE
void
vec_t
<
half
,
4
>::
store
(
half
*
ptr
)
const
{
*
((
uint2
*
)
ptr
)
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
half
,
4
>::
memcpy
(
half
*
dst
,
const
half
*
src
)
{
*
((
uint2
*
)
dst
)
=
*
((
uint2
*
)
src
);
}
// half x 8 or more
template
<
size_t
vec_size
>
struct
vec_t
<
half
,
vec_size
>
{
uint4
data
[
vec_size
/
8
];
FLASHINFER_INLINE
half
&
operator
[](
size_t
i
)
{
return
((
half
*
)
data
)[
i
];
}
FLASHINFER_INLINE
const
half
&
operator
[](
size_t
i
)
const
{
return
((
const
half
*
)
data
)[
i
];
}
FLASHINFER_INLINE
void
fill
(
half
val
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
;
++
i
)
{
*
(
half2
*
)(
&
(
data
[
i
].
x
))
=
make_half2
(
val
,
val
);
*
(
half2
*
)(
&
(
data
[
i
].
y
))
=
make_half2
(
val
,
val
);
*
(
half2
*
)(
&
(
data
[
i
].
z
))
=
make_half2
(
val
,
val
);
*
(
half2
*
)(
&
(
data
[
i
].
w
))
=
make_half2
(
val
,
val
);
}
}
FLASHINFER_INLINE
void
load
(
const
half
*
ptr
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
8
;
++
i
)
{
data
[
i
]
=
((
uint4
*
)
ptr
)[
i
];
}
}
FLASHINFER_INLINE
void
store
(
half
*
ptr
)
const
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
8
;
++
i
)
{
((
uint4
*
)
ptr
)[
i
]
=
data
[
i
];
}
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
vec_size
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
half
*
dst
,
const
half
*
src
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
8
;
++
i
)
{
((
uint4
*
)
dst
)[
i
]
=
((
uint4
*
)
src
)[
i
];
}
}
};
/******************* vec_t<nv_bfloat16> *******************/
// nv_bfloat16 x 1
template
<
>
struct
vec_t
<
nv_bfloat16
,
1
>
{
nv_bfloat16
data
;
FLASHINFER_INLINE
nv_bfloat16
&
operator
[](
size_t
i
)
{
return
((
nv_bfloat16
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
nv_bfloat16
&
operator
[](
size_t
i
)
const
{
return
((
const
nv_bfloat16
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
nv_bfloat16
val
);
FLASHINFER_INLINE
void
load
(
const
nv_bfloat16
*
ptr
);
FLASHINFER_INLINE
void
store
(
nv_bfloat16
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
1
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
nv_bfloat16
*
dst
,
const
nv_bfloat16
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
nv_bfloat16
,
1
>::
fill
(
nv_bfloat16
val
)
{
data
=
val
;
}
FLASHINFER_INLINE
void
vec_t
<
nv_bfloat16
,
1
>::
load
(
const
nv_bfloat16
*
ptr
)
{
data
=
*
ptr
;
}
FLASHINFER_INLINE
void
vec_t
<
nv_bfloat16
,
1
>::
store
(
nv_bfloat16
*
ptr
)
const
{
*
ptr
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
nv_bfloat16
,
1
>::
memcpy
(
nv_bfloat16
*
dst
,
const
nv_bfloat16
*
src
)
{
*
dst
=
*
src
;
}
// nv_bfloat16 x 2
template
<
>
struct
vec_t
<
nv_bfloat16
,
2
>
{
nv_bfloat162
data
;
FLASHINFER_INLINE
nv_bfloat16
&
operator
[](
size_t
i
)
{
return
((
nv_bfloat16
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
nv_bfloat16
&
operator
[](
size_t
i
)
const
{
return
((
const
nv_bfloat16
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
nv_bfloat16
val
);
FLASHINFER_INLINE
void
load
(
const
nv_bfloat16
*
ptr
);
FLASHINFER_INLINE
void
store
(
nv_bfloat16
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
2
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
nv_bfloat16
*
dst
,
const
nv_bfloat16
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
nv_bfloat16
,
2
>::
fill
(
nv_bfloat16
val
)
{
data
=
make_bfloat162
(
val
,
val
);
}
FLASHINFER_INLINE
void
vec_t
<
nv_bfloat16
,
2
>::
load
(
const
nv_bfloat16
*
ptr
)
{
data
=
*
((
nv_bfloat162
*
)
ptr
);
}
FLASHINFER_INLINE
void
vec_t
<
nv_bfloat16
,
2
>::
store
(
nv_bfloat16
*
ptr
)
const
{
*
((
nv_bfloat162
*
)
ptr
)
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
nv_bfloat16
,
2
>::
memcpy
(
nv_bfloat16
*
dst
,
const
nv_bfloat16
*
src
)
{
*
((
nv_bfloat162
*
)
dst
)
=
*
((
nv_bfloat162
*
)
src
);
}
// nv_bfloat16 x 4
template
<
>
struct
vec_t
<
nv_bfloat16
,
4
>
{
uint2
data
;
FLASHINFER_INLINE
nv_bfloat16
&
operator
[](
size_t
i
)
{
return
((
nv_bfloat16
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
nv_bfloat16
&
operator
[](
size_t
i
)
const
{
return
((
const
nv_bfloat16
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
nv_bfloat16
val
);
FLASHINFER_INLINE
void
load
(
const
nv_bfloat16
*
ptr
);
FLASHINFER_INLINE
void
store
(
nv_bfloat16
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
4
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
nv_bfloat16
*
dst
,
const
nv_bfloat16
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
nv_bfloat16
,
4
>::
fill
(
nv_bfloat16
val
)
{
*
(
nv_bfloat162
*
)(
&
data
.
x
)
=
make_bfloat162
(
val
,
val
);
*
(
nv_bfloat162
*
)(
&
data
.
y
)
=
make_bfloat162
(
val
,
val
);
}
FLASHINFER_INLINE
void
vec_t
<
nv_bfloat16
,
4
>::
load
(
const
nv_bfloat16
*
ptr
)
{
data
=
*
((
uint2
*
)
ptr
);
}
FLASHINFER_INLINE
void
vec_t
<
nv_bfloat16
,
4
>::
store
(
nv_bfloat16
*
ptr
)
const
{
*
((
uint2
*
)
ptr
)
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
nv_bfloat16
,
4
>::
memcpy
(
nv_bfloat16
*
dst
,
const
nv_bfloat16
*
src
)
{
*
((
uint2
*
)
dst
)
=
*
((
uint2
*
)
src
);
}
// nv_bfloat16 x 8 or more
template
<
size_t
vec_size
>
struct
vec_t
<
nv_bfloat16
,
vec_size
>
{
uint4
data
[
vec_size
/
8
];
FLASHINFER_INLINE
nv_bfloat16
&
operator
[](
size_t
i
)
{
return
((
nv_bfloat16
*
)
data
)[
i
];
}
FLASHINFER_INLINE
const
nv_bfloat16
&
operator
[](
size_t
i
)
const
{
return
((
const
nv_bfloat16
*
)
data
)[
i
];
}
FLASHINFER_INLINE
void
fill
(
nv_bfloat16
val
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
8
;
++
i
)
{
*
(
nv_bfloat162
*
)(
&
(
data
[
i
].
x
))
=
make_bfloat162
(
val
,
val
);
*
(
nv_bfloat162
*
)(
&
(
data
[
i
].
y
))
=
make_bfloat162
(
val
,
val
);
*
(
nv_bfloat162
*
)(
&
(
data
[
i
].
z
))
=
make_bfloat162
(
val
,
val
);
*
(
nv_bfloat162
*
)(
&
(
data
[
i
].
w
))
=
make_bfloat162
(
val
,
val
);
}
}
FLASHINFER_INLINE
void
load
(
const
nv_bfloat16
*
ptr
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
8
;
++
i
)
{
data
[
i
]
=
((
uint4
*
)
ptr
)[
i
];
}
}
FLASHINFER_INLINE
void
store
(
nv_bfloat16
*
ptr
)
const
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
8
;
++
i
)
{
((
uint4
*
)
ptr
)[
i
]
=
data
[
i
];
}
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
vec_size
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
nv_bfloat16
*
dst
,
const
nv_bfloat16
*
src
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
8
;
++
i
)
{
((
uint4
*
)
dst
)[
i
]
=
((
uint4
*
)
src
)[
i
];
}
}
};
/******************* vec_t<float> *******************/
// float x 1
template
<
>
struct
vec_t
<
float
,
1
>
{
float
data
;
FLASHINFER_INLINE
float
&
operator
[](
size_t
i
)
{
return
((
float
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
float
&
operator
[](
size_t
i
)
const
{
return
((
const
float
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
float
val
);
FLASHINFER_INLINE
void
load
(
const
float
*
ptr
);
FLASHINFER_INLINE
void
store
(
float
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
1
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
float
*
dst
,
const
float
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
float
,
1
>::
fill
(
float
val
)
{
data
=
val
;
}
FLASHINFER_INLINE
void
vec_t
<
float
,
1
>::
load
(
const
float
*
ptr
)
{
data
=
*
ptr
;
}
FLASHINFER_INLINE
void
vec_t
<
float
,
1
>::
store
(
float
*
ptr
)
const
{
*
ptr
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
float
,
1
>::
memcpy
(
float
*
dst
,
const
float
*
src
)
{
*
dst
=
*
src
;
}
// float x 2
template
<
>
struct
vec_t
<
float
,
2
>
{
float2
data
;
FLASHINFER_INLINE
float
&
operator
[](
size_t
i
)
{
return
((
float
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
const
float
&
operator
[](
size_t
i
)
const
{
return
((
const
float
*
)(
&
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
float
val
);
FLASHINFER_INLINE
void
load
(
const
float
*
ptr
);
FLASHINFER_INLINE
void
store
(
float
*
ptr
)
const
;
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
2
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
float
*
dst
,
const
float
*
src
);
};
FLASHINFER_INLINE
void
vec_t
<
float
,
2
>::
fill
(
float
val
)
{
data
=
make_float2
(
val
,
val
);
}
FLASHINFER_INLINE
void
vec_t
<
float
,
2
>::
load
(
const
float
*
ptr
)
{
data
=
*
((
float2
*
)
ptr
);
}
FLASHINFER_INLINE
void
vec_t
<
float
,
2
>::
store
(
float
*
ptr
)
const
{
*
((
float2
*
)
ptr
)
=
data
;
}
FLASHINFER_INLINE
void
vec_t
<
float
,
2
>::
memcpy
(
float
*
dst
,
const
float
*
src
)
{
*
((
float2
*
)
dst
)
=
*
((
float2
*
)
src
);
}
// float x 4 or more
template
<
size_t
vec_size
>
struct
vec_t
<
float
,
vec_size
>
{
float4
data
[
vec_size
/
4
];
FLASHINFER_INLINE
float
&
operator
[](
size_t
i
)
{
return
((
float
*
)(
data
))[
i
];
}
FLASHINFER_INLINE
const
float
&
operator
[](
size_t
i
)
const
{
return
((
const
float
*
)(
data
))[
i
];
}
FLASHINFER_INLINE
void
fill
(
float
val
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
4
;
++
i
)
{
data
[
i
]
=
make_float4
(
val
,
val
,
val
,
val
);
}
}
FLASHINFER_INLINE
void
load
(
const
float
*
ptr
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
4
;
++
i
)
{
data
[
i
]
=
((
float4
*
)
ptr
)[
i
];
}
}
FLASHINFER_INLINE
void
store
(
float
*
ptr
)
const
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
4
;
++
i
)
{
((
float4
*
)
ptr
)[
i
]
=
data
[
i
];
}
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_from
(
const
vec_t
<
T
,
vec_size
>
&
src
)
{
cast_from_impl
(
src
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_load
(
const
T
*
ptr
)
{
cast_load_impl
(
ptr
,
*
this
);
}
template
<
typename
T
>
FLASHINFER_INLINE
void
cast_store
(
T
*
ptr
)
const
{
cast_store_impl
(
*
this
,
ptr
);
}
FLASHINFER_INLINE
static
void
memcpy
(
float
*
dst
,
const
float
*
src
)
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
4
;
++
i
)
{
((
float4
*
)
dst
)[
i
]
=
((
float4
*
)
src
)[
i
];
}
}
};
/******************* vec_t type cast *******************/
template
<
size_t
vec_size
>
FLASHINFER_INLINE
void
cast_from_impl
(
const
vec_t
<
half
,
vec_size
>
&
src
,
vec_t
<
float
,
vec_size
>
&
dst
)
{
if
constexpr
(
vec_size
==
1
)
{
dst
.
data
=
float
(
src
.
data
);
}
else
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
2
;
++
i
)
{
((
float2
*
)(
&
dst
.
data
))[
i
]
=
__half22float2
(((
half2
*
)(
&
src
.
data
))[
i
]);
}
}
}
template
<
size_t
vec_size
>
FLASHINFER_INLINE
void
cast_from_impl
(
const
vec_t
<
float
,
vec_size
>
&
src
,
vec_t
<
half
,
vec_size
>
&
dst
)
{
if
constexpr
(
vec_size
==
1
)
{
dst
.
data
=
half
(
src
.
data
);
}
else
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
2
;
++
i
)
{
((
half2
*
)(
&
dst
.
data
))[
i
]
=
__float22half2_rn
(((
float2
*
)(
&
src
.
data
))[
i
]);
}
}
}
template
<
size_t
vec_size
>
FLASHINFER_INLINE
void
cast_from_impl
(
const
vec_t
<
nv_bfloat16
,
vec_size
>
&
src
,
vec_t
<
float
,
vec_size
>
&
dst
)
{
if
constexpr
(
vec_size
==
1
)
{
dst
.
data
=
float
(
src
.
data
);
}
else
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
2
;
++
i
)
{
((
float2
*
)(
&
dst
.
data
))[
i
]
=
__bfloat1622float2
(((
nv_bfloat162
*
)(
&
src
.
data
))[
i
]);
}
}
}
template
<
size_t
vec_size
>
FLASHINFER_INLINE
void
cast_from_impl
(
const
vec_t
<
float
,
vec_size
>
&
src
,
vec_t
<
nv_bfloat16
,
vec_size
>
&
dst
)
{
if
constexpr
(
vec_size
==
1
)
{
dst
.
data
=
nv_bfloat16
(
src
.
data
);
}
else
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
2
;
++
i
)
{
((
nv_bfloat162
*
)(
&
dst
.
data
))[
i
]
=
__float22bfloat162_rn
(((
float2
*
)(
&
src
.
data
))[
i
]);
}
}
}
#ifdef FLASHINFER_USE_FP8
template
<
size_t
vec_size
>
FLASHINFER_INLINE
void
cast_from_impl
(
const
vec_t
<
__nv_fp8_e4m3
,
vec_size
>
&
src
,
vec_t
<
float
,
vec_size
>
&
dst
)
{
if
constexpr
(
vec_size
==
1
)
{
dst
.
data
=
float
(
src
.
data
);
}
else
if
constexpr
(
vec_size
==
2
)
{
*
(
float2
*
)(
&
dst
.
data
)
=
float2
(
*
(
__nv_fp8x2_e4m3
*
)(
&
src
.
data
));
}
else
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
4
;
++
i
)
{
((
float4
*
)(
&
dst
.
data
))[
i
]
=
float4
(((
__nv_fp8x4_e4m3
*
)(
&
src
.
data
))[
i
]);
}
}
}
template
<
size_t
vec_size
>
FLASHINFER_INLINE
void
cast_from_impl
(
const
vec_t
<
__nv_fp8_e4m3
,
vec_size
>
&
src
,
vec_t
<
half
,
vec_size
>
&
dst
)
{
if
constexpr
(
vec_size
==
1
)
{
dst
.
data
=
float
(
src
.
data
);
}
else
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
2
;
++
i
)
{
((
half2
*
)(
&
dst
.
data
))[
i
]
=
half2
(((
__nv_fp8x2_e4m3
*
)(
&
src
.
data
))[
i
]);
}
}
}
template
<
size_t
vec_size
>
FLASHINFER_INLINE
void
cast_from_impl
(
const
vec_t
<
float
,
vec_size
>
&
src
,
vec_t
<
__nv_fp8_e4m3
,
vec_size
>
&
dst
)
{
if
constexpr
(
vec_size
==
1
)
{
dst
.
data
=
__nv_fp8_e4m3
(
src
.
data
);
}
else
if
constexpr
(
vec_size
==
2
)
{
*
(
__nv_fp8x2_e4m3
*
)(
&
dst
.
data
)
=
__nv_fp8x2_e4m3
(
*
(
float2
*
)(
&
src
.
data
));
}
else
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
4
;
++
i
)
{
((
__nv_fp8x4_e4m3
*
)(
&
dst
.
data
))[
i
]
=
__nv_fp8x4_e4m3
(((
float4
*
)(
&
src
.
data
))[
i
]);
}
}
}
template
<
size_t
vec_size
>
FLASHINFER_INLINE
void
cast_from_impl
(
const
vec_t
<
half
,
vec_size
>
&
src
,
vec_t
<
__nv_fp8_e4m3
,
vec_size
>
&
dst
)
{
if
constexpr
(
vec_size
==
1
)
{
dst
.
data
=
__nv_fp8_e4m3
(
src
.
data
);
}
else
if
constexpr
(
vec_size
==
2
)
{
*
(
__nv_fp8x2_e4m3
*
)(
&
dst
.
data
)
=
__nv_fp8x2_e4m3
(
*
(
half2
*
)(
&
src
.
data
));
}
else
{
#pragma unroll
for
(
size_t
i
=
0
;
i
<
vec_size
/
4
;
++
i
)
{
// NOTE(Zihao): need to double check if we properly handle flo and fhi
((
__nv_fp8x4_e4m3
*
)(
&
dst
.
data
))[
i
]
=
__nv_fp8x4_e4m3
(
((
half2
*
)(
&
src
.
data
))[
i
*
2
],
((
half2
*
)(
&
src
.
data
))[
i
*
2
+
1
]);
}
}
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size>& src,
                                      vec_t<float, vec_size>& dst) {
  if constexpr (vec_size == 1) {
    dst.data = float(src.data);
  } else if constexpr (vec_size == 2) {
    *(float2*)(&dst.data) = float2(*(__nv_fp8x2_e5m2*)(&src.data));
  } else {
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      ((float4*)(&dst.data))[i] = float4(((__nv_fp8x4_e5m2*)(&src.data))[i]);
    }
  }
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<__nv_fp8_e5m2, vec_size>& src,
                                      vec_t<half, vec_size>& dst) {
  if constexpr (vec_size == 1) {
    dst.data = float(src.data);
  } else {
#pragma unroll
    for (size_t i = 0; i < vec_size / 2; ++i) {
      ((half2*)(&dst.data))[i] = half2(((__nv_fp8x2_e5m2*)(&src.data))[i]);
    }
  }
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<float, vec_size>& src,
                                      vec_t<__nv_fp8_e5m2, vec_size>& dst) {
  if constexpr (vec_size == 1) {
    dst.data = __nv_fp8_e5m2(src.data);
  } else if constexpr (vec_size == 2) {
    *(__nv_fp8x2_e5m2*)(&dst.data) = __nv_fp8x2_e5m2(*(float2*)(&src.data));
  } else {
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      ((__nv_fp8x4_e5m2*)(&dst.data))[i] = __nv_fp8x4_e5m2(((float4*)(&src.data))[i]);
    }
  }
}
template <size_t vec_size>
FLASHINFER_INLINE void cast_from_impl(const vec_t<half, vec_size>& src,
                                      vec_t<__nv_fp8_e5m2, vec_size>& dst) {
  if constexpr (vec_size == 1) {
    // The destination element type is e5m2, so the scalar path must also cast
    // to __nv_fp8_e5m2 (not e4m3).
    dst.data = __nv_fp8_e5m2(src.data);
  } else if constexpr (vec_size == 2) {
    *(__nv_fp8x2_e5m2*)(&dst.data) = __nv_fp8x2_e5m2(*(half2*)(&src.data));
  } else {
#pragma unroll
    for (size_t i = 0; i < vec_size / 4; ++i) {
      // NOTE(Zihao): need to double check if we properly handle flo and fhi
      ((__nv_fp8x4_e5m2*)(&dst.data))[i] = __nv_fp8x4_e5m2(
          ((half2*)(&src.data))[i * 2], ((half2*)(&src.data))[i * 2 + 1]);
    }
  }
}
#endif // FLASHINFER_USE_FP8
#endif // VEC_DTYPES_CUH_
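For context, here is a minimal device-side sketch of how these conversion helpers are exercised. It is illustrative only (not part of the commit), assumes vec_dtypes.cuh is on the include path, and relies only on the cast_from_impl overloads and the raw .data member shown above; the function name is hypothetical.

#include <cuda_fp16.h>

#include "vec_dtypes.cuh"  // assumed include path for the header above

// Widen an 8-element half vector to float via the half2 -> float2 fast path
// of cast_from_impl, then reduce it to a scalar.
__device__ float sum_as_float(const vec_t<half, 8>& h_vec) {
  vec_t<float, 8> f_vec;
  cast_from_impl(h_vec, f_vec);  // vec_size == 8 takes the __half22float2 loop
  float acc = 0.f;
#pragma unroll
  for (size_t i = 0; i < 8; ++i) {
    // Element access through the raw .data storage, mirroring the pointer
    // casts used inside the header itself.
    acc += ((const float*)(&f_vec.data))[i];
  }
  return acc;
}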
csrc/punica/punica_ops.cc
0 → 100644
View file @
e00b0a19
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <torch/extension.h>
#include <cstdint>
#include "bgmv/bgmv_config.h"
namespace {

//====== utils ======

inline void check_shape(const torch::Tensor& a, const torch::Tensor& b,
                        const char* a_name, const char* b_name) {
  TORCH_CHECK(a.dim() == b.dim(), a_name, ".dim() != ", b_name, ".dim(). ",
              a.dim(), " vs ", b.dim());
  for (int i = 0; i < a.dim(); ++i) {
    TORCH_CHECK(a.size(i) == b.size(i), a_name, ".size(", i, ") != ", b_name,
                ".size(", i, ")");
  }
}
inline constexpr uint32_t pack_u16(uint16_t a, uint16_t b) {
  return (uint32_t(a) << 16) | uint32_t(b);
}
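// Illustrative only (not part of this file): pack_u16 builds the 32-bit key
// used by the switch in launch_bgmv_kernel below, with the input feature
// count in the high half-word and the output feature count in the low one.
// For a hypothetical (16, 4096) pair:
//   pack_u16(16, 4096) == (16u << 16) | 4096u == 0x00101000
//   pack_u16(4096, 16) == (4096u << 16) | 16u == 0x10000010
static_assert(pack_u16(16, 4096) == 0x00101000u, "high = in_features, low = out_features");
static_assert(pack_u16(4096, 16) == 0x10000010u, "both orientations map to distinct keys");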
#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor")
#define CHECK_CONTIGUOUS(x) \
TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
#define CHECK_INPUT(x) \
CHECK_CUDA(x); \
CHECK_CONTIGUOUS(x)
#define CHECK_DIM(d, x) \
TORCH_CHECK(x.dim() == d, #x " must be a " #d "D tensor")
#define CHECK_SHAPE(a, b) check_shape(a, b, #a, #b)
#define CHECK_EQ(a, b) \
TORCH_CHECK(a == b, "CHECK_EQ(" #a ", " #b ") failed. ", a, " vs ", b)
//====== bgmv ======
template <typename in_T, typename out_T, typename W_T>
inline bool launch_bgmv_kernel(out_T* Y, const in_T* X, const W_T* W,
                               const int64_t* lora_indices,
                               uint16_t in_features, uint16_t out_features,
                               int64_t y_offset, int64_t full_y_size,
                               int64_t batch_size, int64_t num_layers,
                               int64_t layer_idx, float scale) {
  switch (pack_u16(in_features, out_features)) {
#define CASE_ONESIDE(_in_T, _out_T, _W_T, feat_in, feat_out) \
  case pack_u16(feat_in, feat_out):                          \
    bgmv_kernel<feat_in, feat_out>(Y, X, W, lora_indices, y_offset,      \
                                   full_y_size, batch_size, num_layers,  \
                                   layer_idx, scale);                    \
    break;
#define CASE(_in_T, _out_T, _W_T, narrow, wide)  \
  CASE_ONESIDE(in_T, out_T, W_T, narrow, wide)   \
  CASE_ONESIDE(in_T, out_T, W_T, wide, narrow)

    FOR_BGMV_WIDE_NARROW(CASE, _, _, _)
#undef CASE
#undef CASE_ONESIDE

    default:
      return false;
  }

  return true;
}
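// Illustrative expansion (not part of this file): for one hypothetical
// (narrow, wide) pair supplied by FOR_BGMV_WIDE_NARROW, say 16 and 128,
// CASE(_, _, _, 16, 128) above emits both GEMV orientations:
//
//   case pack_u16(16, 128):
//     bgmv_kernel<16, 128>(Y, X, W, lora_indices, y_offset, full_y_size,
//                          batch_size, num_layers, layer_idx, scale);
//     break;
//   case pack_u16(128, 16):
//     bgmv_kernel<128, 16>(Y, X, W, lora_indices, y_offset, full_y_size,
//                          batch_size, num_layers, layer_idx, scale);
//     break;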
void dispatch_bgmv(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                   torch::Tensor indicies, int64_t layer_idx, float scale) {
  CHECK_INPUT(y);
  CHECK_INPUT(x);
  CHECK_INPUT(w);
  CHECK_INPUT(indicies);

  CHECK_DIM(2, y);
  CHECK_DIM(2, x);
  CHECK_DIM(4, w);
  CHECK_DIM(1, indicies);

  int64_t B = x.size(0);
  int64_t h_in = x.size(1);
  int64_t h_out = y.size(1);
  int64_t num_layers = w.size(1);
  CHECK_EQ(w.size(3), h_in);
  CHECK_EQ(w.size(2), h_out);
  CHECK_EQ(indicies.size(0), x.size(0));
  CHECK_EQ(y.size(0), x.size(0));
  bool ok = false;
  if (h_in < 65536 && h_out < 65536) {
    // TODO: See if we can get rid of this massive nested switch
    switch (x.scalar_type()) {
      case at::ScalarType::Half:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half*>(y.data_ptr()), static_cast<nv_half*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half*>(y.data_ptr()), static_cast<nv_half*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16*>(y.data_ptr()), static_cast<nv_half*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16*>(y.data_ptr()), static_cast<nv_half*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float*>(y.data_ptr()), static_cast<nv_half*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float*>(y.data_ptr()), static_cast<nv_half*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half*>(y.data_ptr()), static_cast<nv_bfloat16*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half*>(y.data_ptr()), static_cast<nv_bfloat16*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16*>(y.data_ptr()), static_cast<nv_bfloat16*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16*>(y.data_ptr()), static_cast<nv_bfloat16*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float*>(y.data_ptr()), static_cast<nv_bfloat16*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float*>(y.data_ptr()), static_cast<nv_bfloat16*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      case at::ScalarType::Float:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half*>(y.data_ptr()), static_cast<float*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half*>(y.data_ptr()), static_cast<float*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16*>(y.data_ptr()), static_cast<float*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16*>(y.data_ptr()), static_cast<float*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float*>(y.data_ptr()), static_cast<float*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float*>(y.data_ptr()), static_cast<float*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, 0, h_out, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      default:
        break;
    }
  }
  TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
              " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
void dispatch_bgmv_low_level(torch::Tensor y, torch::Tensor x, torch::Tensor w,
                             torch::Tensor indicies, int64_t layer_idx,
                             float scale, int64_t h_in, int64_t h_out,
                             int64_t y_offset) {
  CHECK_INPUT(y);
  CHECK_INPUT(x);
  CHECK_INPUT(w);
  CHECK_INPUT(indicies);

  CHECK_DIM(2, y);
  CHECK_DIM(2, x);
  CHECK_DIM(4, w);
  CHECK_DIM(1, indicies);

  int64_t B = x.size(0);
  int64_t num_layers = w.size(1);
  int64_t full_y_size = y.size(1);
  CHECK_EQ(w.size(3), h_in);
  CHECK_EQ(w.size(2), h_out);
  CHECK_EQ(indicies.size(0), x.size(0));
  CHECK_EQ(y.size(0), x.size(0));
  bool ok = false;
  if (h_in < 65536 && h_out < 65536) {
    // TODO: See if we can get rid of this massive nested switch
    switch (x.scalar_type()) {
      case at::ScalarType::Half:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half*>(y.data_ptr()), static_cast<nv_half*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half*>(y.data_ptr()), static_cast<nv_half*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16*>(y.data_ptr()), static_cast<nv_half*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16*>(y.data_ptr()), static_cast<nv_half*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float*>(y.data_ptr()), static_cast<nv_half*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float*>(y.data_ptr()), static_cast<nv_half*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      case at::ScalarType::BFloat16:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half*>(y.data_ptr()), static_cast<nv_bfloat16*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half*>(y.data_ptr()), static_cast<nv_bfloat16*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16*>(y.data_ptr()), static_cast<nv_bfloat16*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16*>(y.data_ptr()), static_cast<nv_bfloat16*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float*>(y.data_ptr()), static_cast<nv_bfloat16*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float*>(y.data_ptr()), static_cast<nv_bfloat16*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      case at::ScalarType::Float:
        switch (y.scalar_type()) {
          case at::ScalarType::Half:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half*>(y.data_ptr()), static_cast<float*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_half*>(y.data_ptr()), static_cast<float*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::BFloat16:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16*>(y.data_ptr()), static_cast<float*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<nv_bfloat16*>(y.data_ptr()), static_cast<float*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          case at::ScalarType::Float:
            switch (w.scalar_type()) {
              case at::ScalarType::Half:
                ok = launch_bgmv_kernel(
                    static_cast<float*>(y.data_ptr()), static_cast<float*>(x.data_ptr()),
                    static_cast<nv_half*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              case at::ScalarType::BFloat16:
                ok = launch_bgmv_kernel(
                    static_cast<float*>(y.data_ptr()), static_cast<float*>(x.data_ptr()),
                    static_cast<nv_bfloat16*>(w.data_ptr()), indicies.data_ptr<int64_t>(),
                    h_in, h_out, y_offset, full_y_size, B, num_layers, layer_idx, scale);
                break;
              default:
                break;
            }
            break;
          default:
            break;
        }
        break;
      default:
        break;
    }
  }
  TORCH_CHECK(ok, "No suitable kernel.", " h_in=", h_in, " h_out=", h_out,
              " dtype=", x.scalar_type(), " out_dtype=", y.scalar_type());
}
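// Relationship between the two entry points (illustrative comment, not part
// of this file): dispatch_bgmv above forwards y_offset = 0 and
// full_y_size = h_out, i.e. it is the "cover the whole output row" special
// case of dispatch_bgmv_low_level. Roughly:
//
//   dispatch_bgmv(y, x, w, idx, layer_idx, scale);
//   // behaves like
//   dispatch_bgmv_low_level(y, x, w, idx, layer_idx, scale,
//                           /*h_in=*/x.size(1), /*h_out=*/y.size(1),
//                           /*y_offset=*/0);
//
// whereas a nonzero y_offset lets the caller target an h_out-wide column
// slice of a wider y tensor (full_y_size = y.size(1) columns).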
}  // namespace

//====== pybind ======

#define DEFINE_pybind(name) m.def(#name, &name, #name);

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("dispatch_bgmv", &dispatch_bgmv, "dispatch_bgmv");
  m.def("dispatch_bgmv_low_level", &dispatch_bgmv_low_level,
        "dispatch_bgmv_low_level");
}
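A host-side usage sketch for the bindings above (illustrative, not part of the commit). It assumes it is compiled into the same translation unit, since dispatch_bgmv sits in an anonymous namespace and is otherwise reached through the Python extension; the sizes are hypothetical, and h_in/h_out must be among the instantiated bgmv shapes. The weight tensor follows the checks in dispatch_bgmv: 4-D with num_layers at dim 1, h_out at dim 2, h_in at dim 3, and indicies selecting one weight slice per batch row.

#include <torch/torch.h>

// Hypothetical problem sizes for the sketch.
const int64_t B = 4, h_in = 16, h_out = 4096, num_layers = 1, num_loras = 2;

void run_bgmv_once() {
  auto opts = torch::dtype(torch::kHalf).device(torch::kCUDA);
  auto x = torch::randn({B, h_in}, opts);                            // [B, h_in]
  auto y = torch::zeros({B, h_out}, opts);                           // [B, h_out], written in place
  auto w = torch::randn({num_loras, num_layers, h_out, h_in}, opts); // per-LoRA, per-layer weights
  auto idx = torch::zeros({B}, torch::dtype(torch::kLong).device(torch::kCUDA));
  dispatch_bgmv(y, x, w, idx, /*layer_idx=*/0, /*scale=*/1.0f);
}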
csrc/pybind.cpp
View file @
e00b0a19
...
...
@@ -22,6 +22,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"silu_and_mul"
,
&
silu_and_mul
,
"Activation function used in SwiGLU."
);
ops
.
def
(
"gelu_and_mul"
,
&
gelu_and_mul
,
"Activation function used in GeGLU."
);
ops
.
def
(
"gelu_new"
,
&
gelu_new
,
...
...
@@ -48,13 +52,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    &rotary_embedding,
    "Apply GPT-NeoX or GPT-J style rotary embedding to query and key");
// Quantization ops
#ifndef USE_ROCM
// Quantization ops
  ops.def("awq_gemm", &awq_gemm, "Quantized GEMM for AWQ");
  ops.def("marlin_gemm", &marlin_gemm, "Marlin Optimized Quantized GEMM for GPTQ");
  ops.def("awq_dequantize", &awq_dequantize, "Dequantization for AWQ");
#endif
  ops.def("gptq_gemm", &gptq_gemm, "Quantized GEMM for GPTQ");
  ops.def("gptq_shuffle", &gptq_shuffle, "Post processing for GPTQ");
  ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
  ops.def("moe_align_block_size", &moe_align_block_size,
          "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");
// Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
...
...
@@ -71,9 +82,9 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    &reshape_and_cache,
    "Reshape the key and value tensors and cache them");
  cache_ops.def(
    "gather_cached_kv",
    &gather_cached_kv,
    "Gather key and value from the cache into contiguous QKV tensors");
  cache_ops.def(
    "convert_fp8_e5m2",
    &convert_fp8_e5m2,
    "Convert the key and value cache to fp8_e5m2 data type");
// Cuda utils
  pybind11::module cuda_utils = m.def_submodule("cuda_utils", "vLLM cuda utils");
...
...
@@ -81,4 +92,26 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"get_device_attribute"
,
&
get_device_attribute
,
"Gets the specified device attribute."
);
cuda_utils
.
def
(
"get_max_shared_memory_per_block_device_attribute"
,
&
get_max_shared_memory_per_block_device_attribute
,
"Gets the maximum shared memory per block device attribute."
);
#ifndef USE_ROCM
// Custom all-reduce kernels
  pybind11::module custom_ar = m.def_submodule("custom_ar", "custom allreduce");
  custom_ar.def("init_custom_ar", &init_custom_ar, "init_custom_ar");
  custom_ar.def("should_custom_ar", &should_custom_ar, "should_custom_ar");
  custom_ar.def("all_reduce_reg", &all_reduce_reg, "all_reduce_reg");
  custom_ar.def("all_reduce_unreg", &all_reduce_unreg, "all_reduce_unreg");
  custom_ar.def("dispose", &dispose, "dispose");
  custom_ar.def("meta_size", &meta_size, "meta_size");
  custom_ar.def("register_buffer", &register_buffer, "register_buffer");
  custom_ar.def("get_graph_buffer_ipc_meta", &get_graph_buffer_ipc_meta,
                "get_graph_buffer_ipc_meta");
  custom_ar.def("register_graph_buffers", &register_graph_buffers,
                "register_graph_buffers");
#endif
}