Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5f6d10c1
Unverified
Commit
5f6d10c1
authored
May 22, 2024
by
Michael Goin
Committed by
GitHub
May 22, 2024
Browse files
[CI/Build] Enforce style for C++ and CUDA code with `clang-format` (#4722)
parent
9b9a10d6
Changes
64
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1701 additions
and
1872 deletions
+1701
-1872
csrc/cuda_utils_kernels.cu
csrc/cuda_utils_kernels.cu
+17
-23
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+27
-28
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+51
-54
csrc/custom_all_reduce_test.cu
csrc/custom_all_reduce_test.cu
+19
-19
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+20
-22
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+121
-121
csrc/moe/moe_ops.cpp
csrc/moe/moe_ops.cpp
+2
-1
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+3
-5
csrc/moe_align_block_size_kernels.cu
csrc/moe_align_block_size_kernels.cu
+110
-101
csrc/ops.h
csrc/ops.h
+121
-209
csrc/pos_encoding_kernels.cu
csrc/pos_encoding_kernels.cu
+103
-126
csrc/pybind.cpp
csrc/pybind.cpp
+56
-86
csrc/quantization/aqlm/gemm_kernels.cu
csrc/quantization/aqlm/gemm_kernels.cu
+211
-325
csrc/quantization/awq/dequantize.cuh
csrc/quantization/awq/dequantize.cuh
+76
-62
csrc/quantization/awq/gemm_kernels.cu
csrc/quantization/awq/gemm_kernels.cu
+357
-254
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
+19
-19
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
+11
-11
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
+24
-23
csrc/quantization/fp8/amd/hip_float8.h
csrc/quantization/fp8/amd/hip_float8.h
+93
-123
csrc/quantization/fp8/amd/hip_float8_impl.h
csrc/quantization/fp8/amd/hip_float8_impl.h
+260
-260
No files found.
csrc/cuda_utils_kernels.cu
View file @
5f6d10c1
...
...
@@ -2,28 +2,22 @@
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#endif
int
get_device_attribute
(
int
attribute
,
int
device_id
)
{
int
get_device_attribute
(
int
attribute
,
int
device_id
)
{
int
device
,
value
;
if
(
device_id
<
0
)
{
cudaGetDevice
(
&
device
);
}
else
{
}
else
{
device
=
device_id
;
}
cudaDeviceGetAttribute
(
&
value
,
static_cast
<
cudaDeviceAttr
>
(
attribute
),
device
);
cudaDeviceGetAttribute
(
&
value
,
static_cast
<
cudaDeviceAttr
>
(
attribute
),
device
);
return
value
;
}
int
get_max_shared_memory_per_block_device_attribute
(
int
device_id
)
{
int
attribute
;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
int
get_max_shared_memory_per_block_device_attribute
(
int
device_id
)
{
int
attribute
;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
#ifdef USE_ROCM
attribute
=
hipDeviceAttributeMaxSharedMemoryPerBlock
;
...
...
csrc/custom_all_reduce.cu
View file @
5f6d10c1
...
...
@@ -7,11 +7,11 @@
// fake pointer type
using
fptr_t
=
uint64_t
;
static_assert
(
sizeof
(
void
*
)
==
sizeof
(
fptr_t
));
static_assert
(
sizeof
(
void
*
)
==
sizeof
(
fptr_t
));
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
,
int
rank
,
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int
rank
,
bool
full_nvlink
)
{
int
world_size
=
offsets
.
size
();
if
(
world_size
>
8
)
...
...
@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
std
::
memcpy
(
&
ipc_handles
[
i
],
handles
[
i
].
data
(),
sizeof
(
cudaIpcMemHandle_t
));
}
return
(
fptr_t
)
new
vllm
::
CustomAllreduce
(
reinterpret_cast
<
vllm
::
Signal
*>
(
meta
.
data_ptr
()),
rank_data
.
data_ptr
(),
reinterpret_cast
<
vllm
::
Signal
*>
(
meta
.
data_ptr
()),
rank_data
.
data_ptr
(),
rank_data
.
numel
(),
ipc_handles
,
offsets
,
rank
,
full_nvlink
);
}
...
...
@@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool
_is_weak_contiguous
(
torch
::
Tensor
&
t
)
{
bool
_is_weak_contiguous
(
torch
::
Tensor
&
t
)
{
return
t
.
is_contiguous
()
||
(
t
.
storage
().
nbytes
()
-
t
.
storage_offset
()
*
t
.
element_size
()
==
t
.
numel
()
*
t
.
element_size
());
}
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int
max_size
,
int
world_size
,
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int
max_size
,
int
world_size
,
bool
full_nvlink
)
{
auto
inp_size
=
inp
.
numel
()
*
inp
.
element_size
();
// custom allreduce requires input byte size to be multiples of 16
...
...
@@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
return
false
;
}
void
_all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
void
_all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
cudaStream_t
stream
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
TORCH_CHECK
(
_is_weak_contiguous
(
out
));
switch
(
out
.
scalar_type
())
{
case
at
::
ScalarType
::
Float
:
{
fa
->
allreduce
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
float
*>
(
out
.
data_ptr
()),
fa
->
allreduce
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
float
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
case
at
::
ScalarType
::
Half
:
{
fa
->
allreduce
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
());
fa
->
allreduce
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case
at
::
ScalarType
::
BFloat16
:
{
fa
->
allreduce
<
nv_bfloat16
>
(
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
nv_bfloat16
*>
(
out
.
data_ptr
()),
out
.
numel
());
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
nv_bfloat16
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
#endif
...
...
@@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
}
}
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
)
{
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
TORCH_CHECK_EQ
(
inp
.
scalar_type
(),
out
.
scalar_type
());
...
...
@@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
_all_reduce
(
_fa
,
inp
,
out
,
stream
);
}
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
)
{
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
...
...
@@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
}
void
dispose
(
fptr_t
_fa
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
delete
fa
;
}
int
meta_size
()
{
return
sizeof
(
vllm
::
Signal
);
}
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
fa
->
register_buffer
(
handles
,
offsets
,
t
.
data_ptr
());
}
std
::
pair
<
std
::
vector
<
uint8_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
return
fa
->
get_graph_buffer_ipc_meta
();
}
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
offsets
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
fa
->
register_graph_buffers
(
handles
,
offsets
);
}
csrc/custom_all_reduce.cuh
View file @
5f6d10c1
...
...
@@ -31,9 +31,9 @@ struct Signal {
alignas
(
128
)
uint32_t
end
[
kMaxBlocks
][
8
];
};
struct
__align__
(
16
)
RankData
{
const
void
*
__restrict__
ptrs
[
8
];
};
struct
__align__
(
16
)
RankData
{
const
void
*
__restrict__
ptrs
[
8
];
};
struct
__align__
(
16
)
RankSignals
{
volatile
Signal
*
signals
[
8
];
};
struct
__align__
(
16
)
RankSignals
{
volatile
Signal
*
signals
[
8
];
};
// like std::array, but aligned
template
<
typename
T
,
int
sz
>
...
...
@@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) {
// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE
half
&
assign_add
(
half
&
a
,
half
b
)
{
DINLINE
half
&
assign_add
(
half
&
a
,
half
b
)
{
a
=
__hadd
(
a
,
b
);
return
a
;
}
DINLINE
float
&
assign_add
(
float
&
a
,
float
b
)
{
return
a
+=
b
;
}
DINLINE
float
&
assign_add
(
float
&
a
,
float
b
)
{
return
a
+=
b
;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE
float
upcast_s
(
nv_bfloat16
val
)
{
return
__bfloat162float
(
val
);
}
...
...
@@ -80,14 +80,14 @@ template <>
DINLINE
nv_bfloat16
downcast_s
(
float
val
)
{
return
__float2bfloat16
(
val
);
}
DINLINE
nv_bfloat16
&
assign_add
(
nv_bfloat16
&
a
,
nv_bfloat16
b
)
{
DINLINE
nv_bfloat16
&
assign_add
(
nv_bfloat16
&
a
,
nv_bfloat16
b
)
{
a
=
__hadd
(
a
,
b
);
return
a
;
}
#endif
template
<
typename
T
,
int
N
>
DINLINE
array_t
<
T
,
N
>
&
packed_assign_add
(
array_t
<
T
,
N
>
&
a
,
array_t
<
T
,
N
>
b
)
{
DINLINE
array_t
<
T
,
N
>&
packed_assign_add
(
array_t
<
T
,
N
>&
a
,
array_t
<
T
,
N
>
b
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
assign_add
(
a
.
data
[
i
],
b
.
data
[
i
]);
...
...
@@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template
<
int
ngpus
>
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
int
rank
)
{
if
(
threadIdx
.
x
<
ngpus
)
{
// reset flag for next time
...
...
@@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
]
=
1
;
// wait until we got true from all ranks
while
(
!
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
])
;
while
(
!
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
]);
}
__syncthreads
();
}
...
...
@@ -147,7 +146,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template
<
int
ngpus
,
bool
final_sync
=
false
>
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
int
rank
)
{
__syncthreads
();
// eliminate the case that prior writes are not visible after signals become
...
...
@@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
]
=
1
;
// wait until we got true from all ranks
while
(
!
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
])
;
while
(
!
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
]);
}
if
constexpr
(
!
final_sync
)
__syncthreads
();
}
template
<
typename
P
,
int
ngpus
,
typename
A
>
DINLINE
P
packed_reduce
(
const
P
*
ptrs
[],
int
idx
)
{
DINLINE
P
packed_reduce
(
const
P
*
ptrs
[],
int
idx
)
{
A
tmp
=
upcast
(
ptrs
[
0
][
idx
]);
#pragma unroll
for
(
int
i
=
1
;
i
<
ngpus
;
i
++
)
{
...
...
@@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_1stage
(
RankData
*
_dp
,
RankSignals
sg
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
cross_device_reduce_1stage
(
RankData
*
_dp
,
RankSignals
sg
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
using
P
=
typename
packed_t
<
T
>::
P
;
using
A
=
typename
packed_t
<
T
>::
A
;
...
...
@@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1)
// do the actual reduction
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
}
end_sync
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
}
template
<
typename
P
>
DINLINE
P
*
get_tmp_buf
(
volatile
Signal
*
sg
)
{
return
(
P
*
)(((
Signal
*
)
sg
)
+
1
);
DINLINE
P
*
get_tmp_buf
(
volatile
Signal
*
sg
)
{
return
(
P
*
)(((
Signal
*
)
sg
)
+
1
);
}
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_2stage
(
RankData
*
_dp
,
RankSignals
sg
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
cross_device_reduce_2stage
(
RankData
*
_dp
,
RankSignals
sg
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
gridDim
.
x
*
blockDim
.
x
;
...
...
@@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1)
int
start
=
rank
*
part
;
int
end
=
rank
==
ngpus
-
1
?
size
:
start
+
part
;
int
largest_part
=
part
+
size
%
ngpus
;
const
P
*
ptrs
[
ngpus
];
P
*
tmps
[
ngpus
];
const
P
*
ptrs
[
ngpus
];
P
*
tmps
[
ngpus
];
#pragma unroll
for
(
int
i
=
0
;
i
<
ngpus
;
i
++
)
{
int
target
=
(
rank
+
i
)
%
ngpus
;
ptrs
[
i
]
=
(
const
P
*
)
_dp
->
ptrs
[
target
];
ptrs
[
i
]
=
(
const
P
*
)
_dp
->
ptrs
[
target
];
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
}
auto
tmp_out
=
tmps
[
0
];
...
...
@@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1)
int
gather_from_rank
=
((
rank
+
i
)
%
ngpus
);
if
(
gather_from_rank
==
ngpus
-
1
||
idx
<
part
)
{
int
dst_idx
=
gather_from_rank
*
part
+
idx
;
((
P
*
)
result
)[
dst_idx
]
=
tmps
[
i
][
idx
];
((
P
*
)
result
)[
dst_idx
]
=
tmps
[
i
][
idx
];
}
}
}
...
...
@@ -261,14 +258,14 @@ class CustomAllreduce {
// below are device pointers
RankSignals
sg_
;
std
::
unordered_map
<
void
*
,
RankData
*>
buffers_
;
Signal
*
self_sg_
;
std
::
unordered_map
<
void
*
,
RankData
*>
buffers_
;
Signal
*
self_sg_
;
// stores the registered device pointers from all ranks
RankData
*
d_rank_data_base_
,
*
d_rank_data_end_
;
std
::
vector
<
void
*>
graph_unreg_buffers_
;
std
::
vector
<
void
*>
graph_unreg_buffers_
;
// a map from IPC handles to opened IPC pointers
std
::
map
<
IPC_KEY
,
char
*>
ipc_handles_
;
std
::
map
<
IPC_KEY
,
char
*>
ipc_handles_
;
/**
* meta is a pointer to device metadata and temporary buffer for allreduce.
...
...
@@ -279,22 +276,22 @@ class CustomAllreduce {
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
*/
CustomAllreduce
(
Signal
*
meta
,
void
*
rank_data
,
size_t
rank_data_sz
,
const
cudaIpcMemHandle_t
*
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
,
int
rank
,
CustomAllreduce
(
Signal
*
meta
,
void
*
rank_data
,
size_t
rank_data_sz
,
const
cudaIpcMemHandle_t
*
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int
rank
,
bool
full_nvlink
=
true
)
:
rank_
(
rank
),
world_size_
(
offsets
.
size
()),
full_nvlink_
(
full_nvlink
),
self_sg_
(
meta
),
d_rank_data_base_
(
reinterpret_cast
<
RankData
*>
(
rank_data
)),
d_rank_data_base_
(
reinterpret_cast
<
RankData
*>
(
rank_data
)),
d_rank_data_end_
(
d_rank_data_base_
+
rank_data_sz
/
sizeof
(
RankData
))
{
for
(
int
i
=
0
;
i
<
world_size_
;
i
++
)
{
Signal
*
rank_sg
;
Signal
*
rank_sg
;
if
(
i
!=
rank_
)
{
char
*
handle
=
open_ipc_handle
(
&
handles
[
i
]);
char
*
handle
=
open_ipc_handle
(
&
handles
[
i
]);
handle
+=
offsets
[
i
];
rank_sg
=
(
Signal
*
)
handle
;
rank_sg
=
(
Signal
*
)
handle
;
}
else
{
rank_sg
=
self_sg_
;
}
...
...
@@ -302,13 +299,13 @@ class CustomAllreduce {
}
}
char
*
open_ipc_handle
(
const
void
*
ipc_handle
)
{
char
*
open_ipc_handle
(
const
void
*
ipc_handle
)
{
auto
[
it
,
new_handle
]
=
ipc_handles_
.
insert
({
*
((
IPC_KEY
*
)
ipc_handle
),
nullptr
});
ipc_handles_
.
insert
({
*
((
IPC_KEY
*
)
ipc_handle
),
nullptr
});
if
(
new_handle
)
{
char
*
ipc_ptr
;
CUDACHECK
(
cudaIpcOpenMemHandle
((
void
**
)
&
ipc_ptr
,
*
((
const
cudaIpcMemHandle_t
*
)
ipc_handle
),
char
*
ipc_ptr
;
CUDACHECK
(
cudaIpcOpenMemHandle
((
void
**
)
&
ipc_ptr
,
*
((
const
cudaIpcMemHandle_t
*
)
ipc_handle
),
cudaIpcMemLazyEnablePeerAccess
));
it
->
second
=
ipc_ptr
;
}
...
...
@@ -323,7 +320,7 @@ class CustomAllreduce {
std
::
vector
<
int64_t
>
offsets
(
num_buffers
);
for
(
int
i
=
0
;
i
<
num_buffers
;
i
++
)
{
auto
ptr
=
graph_unreg_buffers_
[
i
];
void
*
base_ptr
;
void
*
base_ptr
;
// note: must share the base address of each allocation, or we get wrong
// address
if
(
cuPointerGetAttribute
(
&
base_ptr
,
...
...
@@ -331,8 +328,8 @@ class CustomAllreduce {
(
CUdeviceptr
)
ptr
)
!=
CUDA_SUCCESS
)
throw
std
::
runtime_error
(
"failed to get pointer attr"
);
CUDACHECK
(
cudaIpcGetMemHandle
(
(
cudaIpcMemHandle_t
*
)
&
handles
[
i
*
handle_sz
],
base_ptr
));
offsets
[
i
]
=
((
char
*
)
ptr
)
-
((
char
*
)
base_ptr
);
(
cudaIpcMemHandle_t
*
)
&
handles
[
i
*
handle_sz
],
base_ptr
));
offsets
[
i
]
=
((
char
*
)
ptr
)
-
((
char
*
)
base_ptr
);
}
return
std
::
make_pair
(
handles
,
offsets
);
}
...
...
@@ -344,13 +341,13 @@ class CustomAllreduce {
std
::
to_string
(
d_rank_data_base_
+
num
-
d_rank_data_end_
));
}
void
register_buffer
(
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
,
void
*
self
)
{
void
register_buffer
(
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
void
*
self
)
{
check_rank_data_capacity
();
RankData
data
;
for
(
int
i
=
0
;
i
<
world_size_
;
i
++
)
{
if
(
i
!=
rank_
)
{
char
*
handle
=
open_ipc_handle
(
handles
[
i
].
data
());
char
*
handle
=
open_ipc_handle
(
handles
[
i
].
data
());
handle
+=
offsets
[
i
];
data
.
ptrs
[
i
]
=
handle
;
}
else
{
...
...
@@ -371,17 +368,17 @@ class CustomAllreduce {
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void
register_graph_buffers
(
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
offsets
)
{
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
auto
num_buffers
=
graph_unreg_buffers_
.
size
();
check_rank_data_capacity
(
num_buffers
);
std
::
vector
<
RankData
>
rank_data
(
num_buffers
);
for
(
int
i
=
0
;
i
<
num_buffers
;
i
++
)
{
auto
self_ptr
=
graph_unreg_buffers_
[
i
];
auto
&
rd
=
rank_data
[
i
];
auto
&
rd
=
rank_data
[
i
];
for
(
int
j
=
0
;
j
<
world_size_
;
j
++
)
{
if
(
j
!=
rank_
)
{
char
*
handle
=
char
*
handle
=
open_ipc_handle
(
&
handles
[
j
][
i
*
sizeof
(
cudaIpcMemHandle_t
)]);
handle
+=
offsets
[
j
][
i
];
rd
.
ptrs
[
j
]
=
handle
;
...
...
@@ -405,7 +402,7 @@ class CustomAllreduce {
* will cause contention on NVLink bus.
*/
template
<
typename
T
>
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
int
threads
=
512
,
int
block_limit
=
36
)
{
auto
d
=
packed_t
<
T
>::
P
::
size
;
if
(
size
%
d
!=
0
)
...
...
@@ -418,7 +415,7 @@ class CustomAllreduce {
std
::
to_string
(
kMaxBlocks
)
+
". Got "
+
std
::
to_string
(
block_limit
));
RankData
*
ptrs
;
RankData
*
ptrs
;
cudaStreamCaptureStatus
status
;
CUDACHECK
(
cudaStreamIsCapturing
(
stream
,
&
status
));
if
(
status
==
cudaStreamCaptureStatusActive
)
{
...
...
csrc/custom_all_reduce_test.cu
View file @
5f6d10c1
...
...
@@ -48,7 +48,7 @@ __global__ void dummy_kernel() {
}
template
<
typename
T
>
__global__
void
set_data
(
T
*
data
,
int
size
,
int
myRank
)
{
__global__
void
set_data
(
T
*
data
,
int
size
,
int
myRank
)
{
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
data
[
idx
]
=
myRank
*
0.11
f
;
...
...
@@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) {
}
template
<
typename
T
>
__global__
void
convert_data
(
const
T
*
data1
,
const
T
*
data2
,
double
*
fdata1
,
double
*
fdata2
,
int
size
)
{
__global__
void
convert_data
(
const
T
*
data1
,
const
T
*
data2
,
double
*
fdata1
,
double
*
fdata2
,
int
size
)
{
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
fdata1
[
idx
]
=
data1
[
idx
];
...
...
@@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1,
}
}
__global__
void
init_rand
(
curandState_t
*
state
,
int
size
,
int
nRanks
)
{
__global__
void
init_rand
(
curandState_t
*
state
,
int
size
,
int
nRanks
)
{
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
...
...
@@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) {
}
template
<
typename
T
>
__global__
void
gen_data
(
curandState_t
*
state
,
T
*
data
,
double
*
ground_truth
,
__global__
void
gen_data
(
curandState_t
*
state
,
T
*
data
,
double
*
ground_truth
,
int
myRank
,
int
nRanks
,
int
size
)
{
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
...
...
@@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
}
template
<
typename
T
>
void
run
(
int
myRank
,
int
nRanks
,
ncclComm_t
&
comm
,
int
threads
,
int
block_limit
,
void
run
(
int
myRank
,
int
nRanks
,
ncclComm_t
&
comm
,
int
threads
,
int
block_limit
,
int
data_size
,
bool
performance_test
)
{
T
*
result
;
T
*
result
;
cudaStream_t
stream
;
CUDACHECK
(
cudaStreamCreateWithFlags
(
&
stream
,
cudaStreamNonBlocking
));
CUDACHECK
(
cudaMalloc
(
&
result
,
data_size
*
sizeof
(
T
)));
...
...
@@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
cudaIpcMemHandle_t
self_data_handle
;
cudaIpcMemHandle_t
data_handles
[
8
];
vllm
::
Signal
*
buffer
;
T
*
self_data_copy
;
vllm
::
Signal
*
buffer
;
T
*
self_data_copy
;
/**
* Allocate IPC buffer
*
...
...
@@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
MPI_BYTE
,
data_handles
,
sizeof
(
cudaIpcMemHandle_t
),
MPI_BYTE
,
MPI_COMM_WORLD
));
void
*
rank_data
;
void
*
rank_data
;
size_t
rank_data_sz
=
16
*
1024
*
1024
;
CUDACHECK
(
cudaMalloc
(
&
rank_data
,
rank_data_sz
));
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
0
);
vllm
::
CustomAllreduce
fa
(
buffer
,
rank_data
,
rank_data_sz
,
data_handles
,
offsets
,
myRank
);
auto
*
self_data
=
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
char
*>
(
buffer
)
+
auto
*
self_data
=
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
char
*>
(
buffer
)
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
// hack buffer registration
{
std
::
vector
<
std
::
string
>
handles
;
handles
.
reserve
(
nRanks
);
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
char
*
begin
=
(
char
*
)
&
data_handles
[
i
];
char
*
end
=
(
char
*
)
&
data_handles
[
i
+
1
];
char
*
begin
=
(
char
*
)
&
data_handles
[
i
];
char
*
end
=
(
char
*
)
&
data_handles
[
i
+
1
];
handles
.
emplace_back
(
begin
,
end
);
}
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
...
...
@@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
fa
.
register_buffer
(
handles
,
offsets
,
self_data
);
}
double
*
ground_truth
;
double
*
ground_truth
;
CUDACHECK
(
cudaMallocHost
(
&
ground_truth
,
data_size
*
sizeof
(
double
)));
curandState_t
*
states
;
curandState_t
*
states
;
CUDACHECK
(
cudaMalloc
(
&
states
,
sizeof
(
curandState_t
)
*
nRanks
*
data_size
));
init_rand
<<<
108
,
1024
,
0
,
stream
>>>
(
states
,
data_size
,
nRanks
);
gen_data
<
T
><<<
108
,
1024
,
0
,
stream
>>>
(
states
,
self_data
,
ground_truth
,
myRank
,
...
...
@@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
CUDACHECK
(
cudaStreamDestroy
(
stream
));
}
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
int
nRanks
,
myRank
;
MPICHECK
(
MPI_Init
(
&
argc
,
&
argv
));
MPICHECK
(
MPI_Comm_rank
(
MPI_COMM_WORLD
,
&
myRank
));
...
...
@@ -296,7 +296,7 @@ int main(int argc, char **argv) {
ncclUniqueId
id
;
ncclComm_t
comm
;
if
(
myRank
==
0
)
ncclGetUniqueId
(
&
id
);
MPICHECK
(
MPI_Bcast
(
static_cast
<
void
*>
(
&
id
),
sizeof
(
id
),
MPI_BYTE
,
0
,
MPICHECK
(
MPI_Bcast
(
static_cast
<
void
*>
(
&
id
),
sizeof
(
id
),
MPI_BYTE
,
0
,
MPI_COMM_WORLD
));
NCCLCHECK
(
ncclCommInitRank
(
&
comm
,
nRanks
,
id
,
myRank
));
...
...
csrc/dispatch_utils.h
View file @
5f6d10c1
...
...
@@ -12,8 +12,7 @@
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
...
...
@@ -22,8 +21,8 @@
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(
\
TYPE, NAME,
VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
AT_DISPATCH_SWITCH(
TYPE, NAME,
\
VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
...
...
@@ -33,5 +32,4 @@
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
csrc/layernorm_kernels.cu
View file @
5f6d10c1
...
...
@@ -11,26 +11,24 @@
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat162
=
__hip_bfloat162
;
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat162
=
__hip_bfloat162
;
#endif
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
rms_norm_kernel
(
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
variance
+=
x
*
x
;
}
variance
=
blockReduceSum
<
float
>
(
variance
);
...
...
@@ -40,12 +38,12 @@ __global__ void rms_norm_kernel(
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)
(
x
*
s_variance
))
*
weight
[
idx
];
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
}
/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
...
...
@@ -56,46 +54,63 @@ __global__ void rms_norm_kernel(
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
*/
template
<
typename
torch_type
>
struct
_typeConvert
{
static
constexpr
bool
exists
=
false
;
};
template
<
typename
torch_type
>
struct
_typeConvert
{
static
constexpr
bool
exists
=
false
;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template
<
>
template
<
>
struct
_typeConvert
<
c10
::
Half
>
{
static
constexpr
bool
exists
=
true
;
using
hip_type
=
__half
;
using
packed_hip_type
=
__half2
;
__device__
static
inline
float
convert
(
hip_type
x
)
{
return
__half2float
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__half22float2
(
x
);
}
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2half_rn
(
x
);
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22half2_rn
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__half22float2
(
x
);
}
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2half_rn
(
x
);
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22half2_rn
(
x
);
}
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template
<
>
template
<
>
struct
_typeConvert
<
c10
::
BFloat16
>
{
static
constexpr
bool
exists
=
true
;
using
hip_type
=
__nv_bfloat16
;
using
packed_hip_type
=
__nv_bfloat162
;
__device__
static
inline
float
convert
(
hip_type
x
)
{
return
__bfloat162float
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__bfloat1622float2
(
x
);
}
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2bfloat16
(
x
);
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22bfloat162_rn
(
x
);
}
__device__
static
inline
float
convert
(
hip_type
x
)
{
return
__bfloat162float
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__bfloat1622float2
(
x
);
}
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2bfloat16
(
x
);
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22bfloat162_rn
(
x
);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template
<
typename
scalar_t
,
int
width
>
template
<
typename
scalar_t
,
int
width
>
struct
alignas
(
16
)
_f16Vec
{
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
...
...
@@ -108,51 +123,49 @@ struct alignas(16) _f16Vec {
__device__
_f16Vec
&
operator
+=
(
const
_f16Vec
<
scalar_t
,
width
>&
other
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
temp
+=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
temp
+=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
data
[
i
]
+=
other
.
data
[
i
];
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
data
[
i
]
+=
other
.
data
[
i
];
}
return
*
this
;
}
__device__
_f16Vec
&
operator
*=
(
const
_f16Vec
<
scalar_t
,
width
>&
other
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
temp
*=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
temp
*=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
data
[
i
]
*=
other
.
data
[
i
];
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
data
[
i
]
*=
other
.
data
[
i
];
}
return
*
this
;
}
__device__
_f16Vec
&
operator
*=
(
const
float
scale
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
float2
temp_f
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
float2
temp_f
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
temp_f
.
x
*=
scale
;
temp_f
.
y
*=
scale
;
T2
temp
=
Converter
::
convert
(
temp_f
);
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
else
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
float
temp
=
Converter
::
convert
(
data
[
i
])
*
scale
;
data
[
i
]
=
Converter
::
convert
(
temp
);
...
...
@@ -164,13 +177,13 @@ struct alignas(16) _f16Vec {
__device__
float
sum_squares
()
const
{
float
result
=
0.0
f
;
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
float2
z
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
float2
z
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
result
+=
z
.
x
*
z
.
x
+
z
.
y
*
z
.
y
;
}
}
else
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
float
x
=
Converter
::
convert
(
data
[
i
]);
result
+=
x
*
x
;
...
...
@@ -184,15 +197,13 @@ struct alignas(16) _f16Vec {
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert
(
std
::
is_pod_v
<
_f16Vec
<
scalar_t
,
width
>>
);
static_assert
(
sizeof
(
_f16Vec
<
scalar_t
,
width
>
)
==
sizeof
(
scalar_t
)
*
width
);
...
...
@@ -203,9 +214,12 @@ __global__ std::enable_if_t<
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto
*
__restrict__
input_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
input
);
auto
*
__restrict__
residual_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
residual
);
auto
*
__restrict__
weight_v
=
reinterpret_cast
<
const
_f16Vec
<
scalar_t
,
width
>*>
(
weight
);
auto
*
__restrict__
input_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
input
);
auto
*
__restrict__
residual_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
residual
);
auto
*
__restrict__
weight_v
=
reinterpret_cast
<
const
_f16Vec
<
scalar_t
,
width
>*>
(
weight
);
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
...
...
@@ -218,7 +232,8 @@ __global__ std::enable_if_t<
calculation of max_block_size in fused_add_rms_norm */
if
(
num_tokens
<
256
)
{
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
...
...
@@ -233,26 +248,23 @@ __global__ std::enable_if_t<
}
}
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
scalar_t
z
=
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
z
+=
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x
=
(
float
)
z
;
float
x
=
(
float
)
z
;
variance
+=
x
*
x
;
residual
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
z
;
}
...
...
@@ -260,22 +272,23 @@ __global__ std::enable_if_t<
calculation of max_block_size in fused_add_rms_norm */
if
(
num_tokens
<
256
)
{
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
input
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)
(
x
*
s_variance
))
*
weight
[
idx
];
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
input
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
}
}
// namespace vllm
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
float
epsilon
)
{
...
...
@@ -286,37 +299,24 @@ void rms_norm(
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"fused_add_rms_norm_kernel", \
[&] { \
vllm::fused_add_rms_norm_kernel \
<scalar_t, width><<<grid, block, 0, stream>>>( \
input.data_ptr<scalar_t>(), \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), \
epsilon, \
num_tokens, \
hidden_size); \
weight.data_ptr<scalar_t>(), epsilon, \
num_tokens, hidden_size); \
});
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
// [..., hidden_size]
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
float
epsilon
)
{
...
...
@@ -342,8 +342,8 @@ void fused_add_rms_norm(
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
\
&&
wt_ptr
%
16
==
0
;
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
...
...
csrc/moe/moe_ops.cpp
View file @
5f6d10c1
...
...
@@ -3,5 +3,6 @@
#include <torch/extension.h>
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"topk_softmax"
,
&
topk_softmax
,
"Apply topk softmax to the gating outputs."
);
m
.
def
(
"topk_softmax"
,
&
topk_softmax
,
"Apply topk softmax to the gating outputs."
);
}
csrc/moe/moe_ops.h
View file @
5f6d10c1
...
...
@@ -2,8 +2,6 @@
#include <torch/extension.h>
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
gating_output
);
csrc/moe_align_block_size_kernels.cu
View file @
5f6d10c1
...
...
@@ -7,32 +7,35 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
#define CEILDIV(x,
y) (((x) + (y) - 1) / (y))
namespace
vllm
{
namespace
{
__device__
__forceinline__
int32_t
index
(
int32_t
total_col
,
int32_t
row
,
int32_t
col
)
{
__device__
__forceinline__
int32_t
index
(
int32_t
total_col
,
int32_t
row
,
int32_t
col
)
{
// don't worry about overflow because num_experts is relatively small
return
row
*
total_col
+
col
;
}
}
}
// namespace
template
<
typename
scalar_t
>
__global__
void
moe_align_block_size_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
sorted_token_ids
,
int32_t
*
expert_ids
,
int32_t
*
total_tokens_post_pad
,
__global__
void
moe_align_block_size_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
sorted_token_ids
,
int32_t
*
expert_ids
,
int32_t
*
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
)
{
int32_t
block_size
,
size_t
numel
)
{
const
size_t
tokens_per_thread
=
CEILDIV
(
numel
,
blockDim
.
x
);
const
size_t
start_idx
=
threadIdx
.
x
*
tokens_per_thread
;
extern
__shared__
int32_t
shared_mem
[];
int32_t
*
tokens_cnts
=
shared_mem
;
// 2d tensor with shape (num_experts + 1, num_experts)
int32_t
*
cumsum
=
shared_mem
+
(
num_experts
+
1
)
*
num_experts
;
// 1d tensor with shape (num_experts + 1)
int32_t
*
tokens_cnts
=
shared_mem
;
// 2d tensor with shape (num_experts + 1, num_experts)
int32_t
*
cumsum
=
shared_mem
+
(
num_experts
+
1
)
*
num_experts
;
// 1d tensor with shape (num_experts + 1)
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
i
)]
=
0
;
...
...
@@ -40,8 +43,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
/**
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are
assigned
* to expert expert_index.
* which counts how many tokens in the token shard of thread_index are
*
assigned
to expert expert_index.
*/
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
topk_ids
[
i
])];
...
...
@@ -52,7 +55,8 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
// For each expert we accumulate the token counts from the different threads.
tokens_cnts
[
index
(
num_experts
,
0
,
threadIdx
.
x
)]
=
0
;
for
(
int
i
=
1
;
i
<=
blockDim
.
x
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
i
,
threadIdx
.
x
)]
+=
tokens_cnts
[
index
(
num_experts
,
i
-
1
,
threadIdx
.
x
)];
tokens_cnts
[
index
(
num_experts
,
i
,
threadIdx
.
x
)]
+=
tokens_cnts
[
index
(
num_experts
,
i
-
1
,
threadIdx
.
x
)];
}
__syncthreads
();
...
...
@@ -61,7 +65,10 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
if
(
threadIdx
.
x
==
0
)
{
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
CEILDIV
(
tokens_cnts
[
index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
block_size
)
*
block_size
;
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
CEILDIV
(
tokens_cnts
[
index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
block_size
)
*
block_size
;
}
*
total_tokens_post_pad
=
cumsum
[
num_experts
];
}
...
...
@@ -69,57 +76,59 @@ __global__ void moe_align_block_size_kernel(scalar_t *__restrict__ topk_ids,
__syncthreads
();
/**
* For each expert, each thread processes the tokens of the corresponding
blocks
* and stores the corresponding expert_id for each block.
* For each expert, each thread processes the tokens of the corresponding
*
blocks
and stores the corresponding expert_id for each block.
*/
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
i
+=
block_size
)
{
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
i
+=
block_size
)
{
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
}
/**
* Each thread processes a token shard, calculating the index of each token after
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
* where * represents a padding value(preset in python).
* Each thread processes a token shard, calculating the index of each token
* after sorting by expert number. Given the example topk_ids =
* [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
* *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
*/
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
int32_t
expert_id
=
topk_ids
[
i
];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
* stores the indices of the tokens processed by the expert with expert_id within
* the current thread's token shard.
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t
rank_post_pad
=
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)]
+
cumsum
[
expert_id
];
int32_t
rank_post_pad
=
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)]
+
cumsum
[
expert_id
];
sorted_token_ids
[
rank_post_pad
]
=
i
;
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)];
}
}
}
}
// namespace vllm
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
num_experts
,
int
block_size
,
torch
::
Tensor
sorted_token_ids
,
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
num_experts
,
int
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
)
{
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
const
int32_t
shared_mem
=
((
num_experts
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
sizeof
(
int32_t
);
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const
int32_t
shared_mem
=
((
num_experts
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
sizeof
(
int32_t
);
// set dynamic shared mem
auto
kernel
=
vllm
::
moe_align_block_size_kernel
<
scalar_t
>
;
AT_CUDA_CHECK
(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize
(
(
void
*
)
kernel
,
shared_mem
));
AT_CUDA_CHECK
(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize
(
(
void
*
)
kernel
,
shared_mem
));
kernel
<<<
1
,
num_experts
,
shared_mem
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
());
});
}
csrc/ops.h
View file @
5f6d10c1
...
...
@@ -2,224 +2,136 @@
#include <torch/extension.h>
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
block_size
,
int
max_seq_len
,
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
block_size
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
void
paged_attention_v2
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
block_size
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
void
paged_attention_v2
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
block_size
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
float
epsilon
);
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
float
epsilon
);
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
float
epsilon
);
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_tanh_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_tanh_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_new
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_new
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_fast
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_fast
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
#ifndef USE_ROCM
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
torch
::
Tensor
&
codebook_partition_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
);
const
std
::
optional
<
torch
::
Tensor
>&
bias
);
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codebook_partition_sizes
);
const
torch
::
Tensor
&
codebook_partition_sizes
);
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
);
torch
::
Tensor
awq_dequantize
(
torch
::
Tensor
_kernel
,
torch
::
Tensor
awq_dequantize
(
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
,
int
thx
,
torch
::
Tensor
_zeros
,
int
split_k_iters
,
int
thx
,
int
thy
);
torch
::
Tensor
marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_24_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_24_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
);
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
);
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
int
cutlass_scaled_mm_dq
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
int
cutlass_scaled_mm_dq
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
#endif
void
squeezellm_gemm
(
torch
::
Tensor
vec
,
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
void
squeezellm_gemm
(
torch
::
Tensor
vec
,
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
torch
::
Tensor
lookup_table
);
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
bool
use_exllama
,
int
bit
);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int
bit
);
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
bool
use_exllama
,
int
bit
);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int
bit
);
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
scale
);
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
scale
);
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
num_experts
,
int
block_size
,
torch
::
Tensor
sorted_token_ids
,
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
num_experts
,
int
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
#ifndef USE_ROCM
using
fptr_t
=
uint64_t
;
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
,
int
rank
,
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int
rank
,
bool
full_nvlink
);
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int
max_size
,
int
world_size
,
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int
max_size
,
int
world_size
,
bool
full_nvlink
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
);
void
dispose
(
fptr_t
_fa
);
int
meta_size
();
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
);
std
::
pair
<
std
::
vector
<
uint8_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
);
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
offsets
);
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
);
std
::
pair
<
std
::
vector
<
uint8_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
);
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
);
#endif
csrc/pos_encoding_kernels.cu
View file @
5f6d10c1
...
...
@@ -7,14 +7,10 @@
namespace
vllm
{
template
<
typename
scalar_t
,
bool
IS_NEOX
>
template
<
typename
scalar_t
,
bool
IS_NEOX
>
inline
__device__
void
apply_token_rotary_embedding
(
scalar_t
*
__restrict__
arr
,
const
scalar_t
*
__restrict__
cos_ptr
,
const
scalar_t
*
__restrict__
sin_ptr
,
int
rot_offset
,
int
embed_dim
)
{
scalar_t
*
__restrict__
arr
,
const
scalar_t
*
__restrict__
cos_ptr
,
const
scalar_t
*
__restrict__
sin_ptr
,
int
rot_offset
,
int
embed_dim
)
{
int
x_index
,
y_index
;
scalar_t
cos
,
sin
;
if
(
IS_NEOX
)
{
...
...
@@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding(
arr
[
y_index
]
=
y
*
cos
+
x
*
sin
;
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
template
<
typename
scalar_t
,
bool
IS_NEOX
>
inline
__device__
void
apply_rotary_embedding
(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const
scalar_t
*
cache_ptr
,
const
int
head_size
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
rot_dim
,
const
int
token_idx
,
const
int64_t
query_stride
,
const
int64_t
key_stride
)
{
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
cache_ptr
,
const
int
head_size
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
rot_dim
,
const
int
token_idx
,
const
int64_t
query_stride
,
const
int64_t
key_stride
)
{
const
int
embed_dim
=
rot_dim
/
2
;
const
scalar_t
*
cos_ptr
=
cache_ptr
;
const
scalar_t
*
sin_ptr
=
cache_ptr
+
embed_dim
;
...
...
@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_size
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
const
int
nk
=
num_kv_heads
*
embed_dim
;
...
...
@@ -68,59 +62,71 @@ inline __device__ void apply_rotary_embedding(
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
template
<
typename
scalar_t
,
bool
IS_NEOX
>
__global__
void
rotary_embedding_kernel
(
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or [num_tokens]
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim // 2]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or
// [num_tokens]
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
// 2]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
pos
*
rot_dim
;
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
template
<
typename
scalar_t
,
bool
IS_NEOX
>
__global__
void
batched_rotary_embedding_kernel
(
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or [num_tokens]
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim // 2]
const
int64_t
*
__restrict__
cos_sin_cache_offsets
,
// [batch_size, seq_len] or [num_tokens]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or
// [num_tokens]
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
// 2]
const
int64_t
*
__restrict__
cos_sin_cache_offsets
,
// [batch_size, seq_len]
// or [num_tokens]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
int64_t
cos_sin_cache_offset
=
cos_sin_cache_offsets
[
token_idx
];
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
(
cos_sin_cache_offset
+
pos
)
*
rot_dim
;
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
(
cos_sin_cache_offset
+
pos
)
*
rot_dim
;
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
}
}
// namespace vllm
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
// [batch_size, seq_len] or [num_tokens]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
)
{
...
...
@@ -135,33 +141,18 @@ void rotary_embedding(
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"rotary_embedding"
,
[
&
]
{
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"rotary_embedding"
,
[
&
]
{
if
(
is_neox
)
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
vllm
::
rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
...
...
@@ -173,12 +164,13 @@ and process in batched manner.
*/
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
// [batch_size, seq_len] or [num_tokens]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
,
int
rot_dim
,
bool
is_neox
,
int
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
// [num_tokens]
)
{
int64_t
num_tokens
=
cos_sin_cache_offsets
.
size
(
0
);
...
...
@@ -191,36 +183,21 @@ void batched_rotary_embedding(
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"rotary_embedding"
,
[
&
]
{
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"rotary_embedding"
,
[
&
]
{
if
(
is_neox
)
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
true
>
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
}
csrc/pybind.cpp
View file @
5f6d10c1
...
...
@@ -8,114 +8,85 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11
::
module
ops
=
m
.
def_submodule
(
"ops"
,
"vLLM custom operators"
);
// Attention ops
ops
.
def
(
"paged_attention_v1"
,
&
paged_attention_v1
,
"Compute the attention between an input query and the cached keys/values using PagedAttention."
);
ops
.
def
(
"paged_attention_v2"
,
&
paged_attention_v2
,
"PagedAttention V2."
);
ops
.
def
(
"paged_attention_v1"
,
&
paged_attention_v1
,
"Compute the attention between an input query and the cached "
"keys/values using PagedAttention."
);
ops
.
def
(
"paged_attention_v2"
,
&
paged_attention_v2
,
"PagedAttention V2."
);
// Activation ops
ops
.
def
(
"silu_and_mul"
,
&
silu_and_mul
,
"Activation function used in SwiGLU."
);
ops
.
def
(
"gelu_and_mul"
,
&
gelu_and_mul
,
ops
.
def
(
"silu_and_mul"
,
&
silu_and_mul
,
"Activation function used in SwiGLU."
);
ops
.
def
(
"gelu_and_mul"
,
&
gelu_and_mul
,
"Activation function used in GeGLU with `none` approximation."
);
ops
.
def
(
"gelu_tanh_and_mul"
,
&
gelu_tanh_and_mul
,
ops
.
def
(
"gelu_tanh_and_mul"
,
&
gelu_tanh_and_mul
,
"Activation function used in GeGLU with `tanh` approximation."
);
ops
.
def
(
"gelu_new"
,
&
gelu_new
,
"GELU implementation used in GPT-2."
);
ops
.
def
(
"gelu_fast"
,
&
gelu_fast
,
"Approximate GELU implementation."
);
ops
.
def
(
"gelu_new"
,
&
gelu_new
,
"GELU implementation used in GPT-2."
);
ops
.
def
(
"gelu_fast"
,
&
gelu_fast
,
"Approximate GELU implementation."
);
// Layernorm
ops
.
def
(
"rms_norm"
,
&
rms_norm
,
ops
.
def
(
"rms_norm"
,
&
rms_norm
,
"Apply Root Mean Square (RMS) Normalization to the input tensor."
);
ops
.
def
(
"fused_add_rms_norm"
,
&
fused_add_rms_norm
,
ops
.
def
(
"fused_add_rms_norm"
,
&
fused_add_rms_norm
,
"In-place fused Add and RMS Normalization"
);
// Rotary embedding
ops
.
def
(
"rotary_embedding"
,
&
rotary_embedding
,
ops
.
def
(
"rotary_embedding"
,
&
rotary_embedding
,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key"
);
ops
.
def
(
"batched_rotary_embedding"
,
&
batched_rotary_embedding
,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)"
);
ops
.
def
(
"batched_rotary_embedding"
,
&
batched_rotary_embedding
,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key "
"(supports multiple loras)"
);
// Quantization ops
#ifndef USE_ROCM
ops
.
def
(
"aqlm_gemm"
,
&
aqlm_gemm
,
"Quantized GEMM for AQLM"
);
ops
.
def
(
"aqlm_dequant"
,
&
aqlm_dequant
,
"Decompression method for AQLM"
);
ops
.
def
(
"awq_gemm"
,
&
awq_gemm
,
"Quantized GEMM for AWQ"
);
ops
.
def
(
"marlin_gemm"
,
&
marlin_gemm
,
"Marlin (Dense) Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_24_gemm"
,
&
gptq_marlin_24_gemm
,
"Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
,
"gptq_marlin Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
,
"gptq_marlin repack from GPTQ"
);
ops
.
def
(
"marlin_gemm"
,
&
marlin_gemm
,
"Marlin (Dense) Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_24_gemm"
,
&
gptq_marlin_24_gemm
,
"Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
,
"gptq_marlin Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
,
"gptq_marlin repack from GPTQ"
);
ops
.
def
(
"awq_dequantize"
,
&
awq_dequantize
,
"Dequantization for AWQ"
);
ops
.
def
(
"cutlass_scaled_mm_dq"
,
&
cutlass_scaled_mm_dq
,
"CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization."
);
ops
.
def
(
"cutlass_scaled_mm_dq"
,
&
cutlass_scaled_mm_dq
,
"CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
"per-row/column quantization."
);
#endif
ops
.
def
(
"gptq_gemm"
,
&
gptq_gemm
,
"Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_shuffle"
,
&
gptq_shuffle
,
"Post processing for GPTQ"
);
ops
.
def
(
"squeezellm_gemm"
,
&
squeezellm_gemm
,
"Quantized GEMM for SqueezeLLM"
);
ops
.
def
(
"static_scaled_fp8_quant"
,
&
static_scaled_fp8_quant
,
"Compute FP8 quantized tensor for given scaling factor"
);
ops
.
def
(
"dynamic_scaled_fp8_quant"
,
&
dynamic_scaled_fp8_quant
,
"Compute FP8 quantized tensor and scaling factor"
);
ops
.
def
(
"moe_align_block_size"
,
&
moe_align_block_size
,
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."
);
ops
.
def
(
"static_scaled_fp8_quant"
,
&
static_scaled_fp8_quant
,
"Compute FP8 quantized tensor for given scaling factor"
);
ops
.
def
(
"dynamic_scaled_fp8_quant"
,
&
dynamic_scaled_fp8_quant
,
"Compute FP8 quantized tensor and scaling factor"
);
ops
.
def
(
"moe_align_block_size"
,
&
moe_align_block_size
,
"Aligning the number of tokens to be processed by each expert such "
"that it is divisible by the block size."
);
// Cache ops
pybind11
::
module
cache_ops
=
m
.
def_submodule
(
"cache_ops"
,
"vLLM cache ops"
);
cache_ops
.
def
(
"swap_blocks"
,
&
swap_blocks
,
cache_ops
.
def
(
"swap_blocks"
,
&
swap_blocks
,
"Swap in (out) the cache blocks from src to dst"
);
cache_ops
.
def
(
"copy_blocks"
,
&
copy_blocks
,
cache_ops
.
def
(
"copy_blocks"
,
&
copy_blocks
,
"Copy the cache blocks from src to dst"
);
cache_ops
.
def
(
"reshape_and_cache"
,
&
reshape_and_cache
,
cache_ops
.
def
(
"reshape_and_cache"
,
&
reshape_and_cache
,
"Reshape the key and value tensors and cache them"
);
cache_ops
.
def
(
"reshape_and_cache_flash"
,
&
reshape_and_cache_flash
,
cache_ops
.
def
(
"reshape_and_cache_flash"
,
&
reshape_and_cache_flash
,
"Reshape the key and value tensors and cache them"
);
cache_ops
.
def
(
"convert_fp8"
,
&
convert_fp8
,
cache_ops
.
def
(
"convert_fp8"
,
&
convert_fp8
,
"Convert the key and value cache to fp8 data type"
);
// Cuda utils
pybind11
::
module
cuda_utils
=
m
.
def_submodule
(
"cuda_utils"
,
"vLLM cuda utils"
);
cuda_utils
.
def
(
"get_device_attribute"
,
&
get_device_attribute
,
pybind11
::
module
cuda_utils
=
m
.
def_submodule
(
"cuda_utils"
,
"vLLM cuda utils"
);
cuda_utils
.
def
(
"get_device_attribute"
,
&
get_device_attribute
,
"Gets the specified device attribute."
);
cuda_utils
.
def
(
"get_max_shared_memory_per_block_device_attribute"
,
cuda_utils
.
def
(
"get_max_shared_memory_per_block_device_attribute"
,
&
get_max_shared_memory_per_block_device_attribute
,
"Gets the maximum shared memory per block device attribute."
);
...
...
@@ -134,5 +105,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
custom_ar
.
def
(
"register_graph_buffers"
,
&
register_graph_buffers
,
"register_graph_buffers"
);
#endif
}
csrc/quantization/aqlm/gemm_kernels.cu
View file @
5f6d10c1
...
...
@@ -25,30 +25,26 @@
#include <iostream>
#include <cstdlib>
namespace
vllm
{
namespace
aqlm
{
__global__
void
Code1x16MatVec
(
const
int4
*
__restrict__
A
,
const
int4
*
__restrict__
B
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
const
int
prob_m
,
const
int4
*
__restrict__
A
,
const
int4
*
__restrict__
B
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
const
int
prob_m
,
const
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at most 3 long.
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each
// codebook, at most 3 long.
const
int
codebook_stride
// as int4.
)
{
int
a_gl_stride
=
prob_k
/
8
/
8
;
int
a_gl_rd
=
(
blockDim
.
x
/
32
)
*
blockIdx
.
x
+
(
threadIdx
.
x
/
32
);
bool
pred
=
a_gl_rd
<
prob_m
;
if
(
pred
)
{
//
advance to the correct codebook, this easy because we only multiply one
column of the codebook.
if
(
pred
)
{
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto
codebook_size
=
&
codebook_a_sizes
.
x
;
while
(
a_gl_rd
>=
*
codebook_size
)
{
while
(
a_gl_rd
>=
*
codebook_size
)
{
codebook
+=
codebook_stride
;
++
codebook_size
;
}
...
...
@@ -67,8 +63,7 @@ __global__ void Code1x16MatVec(
// We pad shared memory to avoid bank conflicts during reads
__syncthreads
();
for
(
int
i
=
threadIdx
.
x
;
i
<
32
*
8
;
i
+=
blockDim
.
x
)
{
if
(
b_gl_rd
+
i
<
prob_k
/
8
)
sh_b
[
9
*
(
i
/
8
)
+
i
%
8
]
=
B
[
b_gl_rd
+
i
];
if
(
b_gl_rd
+
i
<
prob_k
/
8
)
sh_b
[
9
*
(
i
/
8
)
+
i
%
8
]
=
B
[
b_gl_rd
+
i
];
}
__syncthreads
();
b_gl_rd
+=
32
*
8
;
...
...
@@ -76,22 +71,19 @@ __global__ void Code1x16MatVec(
int
b_sh_rd
=
9
*
(
threadIdx
.
x
%
32
);
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
const
uint16_t
*
enc
=
reinterpret_cast
<
const
uint16_t
*>
(
&
A
[
a_gl_rd
]);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
uint32_t
dec
[
4
];
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
// actually help us; this brings > 2x speedup.
asm
volatile
(
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
// We bypass the L1 cache to avoid massive amounts of memory streaming
// that doesn't actually help us; this brings > 2x speedup.
asm
volatile
(
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
:
"=r"
(
dec
[
0
]),
"=r"
(
dec
[
1
]),
"=r"
(
dec
[
2
]),
"=r"
(
dec
[
3
])
:
"l"
((
void
*
)
&
codebook
[
enc
[
i
]])
);
:
"l"
((
void
*
)
&
codebook
[
enc
[
i
]]));
half2
*
a
=
reinterpret_cast
<
half2
*>
(
&
dec
);
half2
*
b
=
reinterpret_cast
<
half2
*>
(
&
sh_b
[
b_sh_rd
]);
half2
res2
=
{};
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
res2
=
__hfma2
(
a
[
j
],
b
[
j
],
res2
);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
res2
=
__hfma2
(
a
[
j
],
b
[
j
],
res2
);
res
+=
__half2float
(
res2
.
x
)
+
__half2float
(
res2
.
y
);
b_sh_rd
++
;
}
...
...
@@ -100,22 +92,19 @@ __global__ void Code1x16MatVec(
}
if
(
pred
)
{
#pragma unroll
for
(
int
i
=
16
;
i
>
0
;
i
/=
2
)
res
+=
__shfl_down_sync
(
0xffffffff
,
res
,
i
);
#pragma unroll
for
(
int
i
=
16
;
i
>
0
;
i
/=
2
)
res
+=
__shfl_down_sync
(
0xffffffff
,
res
,
i
);
if
(
threadIdx
.
x
%
32
==
0
)
reinterpret_cast
<
__half
*>
(
C
)[
c_gl_wr
]
=
__float2half
(
res
);
}
}
__global__
void
Code2x8MatVec
(
const
int4
*
__restrict__
A
,
const
int4
*
__restrict__
B
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
int
prob_m
,
const
int4
*
__restrict__
A
,
const
int4
*
__restrict__
B
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at most 3 long.
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each
// codebook, at most 3 long.
const
int
codebook_stride
// as int4.
)
{
...
...
@@ -123,12 +112,11 @@ __global__ void Code2x8MatVec(
int
a_gl_rd
=
(
blockDim
.
x
/
32
)
*
blockIdx
.
x
+
(
threadIdx
.
x
/
32
);
bool
pred
=
a_gl_rd
<
prob_m
;
if
(
pred
)
{
//
advance to the correct codebook, this easy because we only multiply one
column of the codebook.
if
(
pred
)
{
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto
codebook_size
=
&
codebook_a_sizes
.
x
;
while
(
a_gl_rd
>=
*
codebook_size
)
{
while
(
a_gl_rd
>=
*
codebook_size
)
{
codebook
+=
codebook_stride
;
++
codebook_size
;
}
...
...
@@ -148,9 +136,8 @@ __global__ void Code2x8MatVec(
for
(
int
i
=
threadIdx
.
x
;
i
<
2
*
256
;
i
+=
blockDim
.
x
)
{
int4
dec
=
codebook
[
i
];
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
sh_code
[
8
*
i
+
(
j
+
lane
)
%
8
]
=
dec
;
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
sh_code
[
8
*
i
+
(
j
+
lane
)
%
8
]
=
dec
;
}
__syncthreads
();
...
...
@@ -161,8 +148,7 @@ __global__ void Code2x8MatVec(
// We pad shared memory to avoid bank conflicts during reads
__syncthreads
();
for
(
int
i
=
threadIdx
.
x
;
i
<
32
*
8
;
i
+=
blockDim
.
x
)
{
if
(
b_gl_rd
+
i
<
prob_k
/
8
)
sh_b
[
9
*
(
i
/
8
)
+
i
%
8
]
=
B
[
b_gl_rd
+
i
];
if
(
b_gl_rd
+
i
<
prob_k
/
8
)
sh_b
[
9
*
(
i
/
8
)
+
i
%
8
]
=
B
[
b_gl_rd
+
i
];
}
__syncthreads
();
b_gl_rd
+=
32
*
8
;
...
...
@@ -170,13 +156,15 @@ __global__ void Code2x8MatVec(
int
b_sh_rd
=
9
*
(
threadIdx
.
x
%
32
);
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
const
uint8_t
*
enc
=
reinterpret_cast
<
const
uint8_t
*>
(
&
A
[
a_gl_rd
]);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
half2
*
a0
=
reinterpret_cast
<
half2
*>
(
&
sh_code0
[
8
*
enc
[
2
*
i
+
0
]
+
lane
]);
half2
*
a1
=
reinterpret_cast
<
half2
*>
(
&
sh_code1
[
8
*
enc
[
2
*
i
+
1
]
+
lane
]);
half2
*
a0
=
reinterpret_cast
<
half2
*>
(
&
sh_code0
[
8
*
enc
[
2
*
i
+
0
]
+
lane
]);
half2
*
a1
=
reinterpret_cast
<
half2
*>
(
&
sh_code1
[
8
*
enc
[
2
*
i
+
1
]
+
lane
]);
half2
*
b
=
reinterpret_cast
<
half2
*>
(
&
sh_b
[
b_sh_rd
]);
half2
res2
=
{};
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
res2
=
__hfma2
(
__hadd2
(
a0
[
j
],
a1
[
j
]),
b
[
j
],
res2
);
res
+=
__half2float
(
res2
.
x
)
+
__half2float
(
res2
.
y
);
...
...
@@ -187,34 +175,29 @@ __global__ void Code2x8MatVec(
}
if
(
pred
)
{
#pragma unroll
for
(
int
i
=
16
;
i
>
0
;
i
/=
2
)
res
+=
__shfl_down_sync
(
0xffffffff
,
res
,
i
);
#pragma unroll
for
(
int
i
=
16
;
i
>
0
;
i
/=
2
)
res
+=
__shfl_down_sync
(
0xffffffff
,
res
,
i
);
if
(
threadIdx
.
x
%
32
==
0
)
reinterpret_cast
<
__half
*>
(
C
)[
c_gl_wr
]
=
__float2half
(
res
);
}
}
__global__
void
Code1x16Dequant
(
const
int4
*
__restrict__
A
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
const
int4
*
__restrict__
A
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each
// codebook, at most 3 long, sums to m.
const
int
codebook_stride
// as int4
)
{
int
a_gl_stride
=
prob_k
/
8
/
8
;
int
a_gl_rd
=
(
blockDim
.
x
/
32
)
*
blockIdx
.
x
+
(
threadIdx
.
x
/
32
);
bool
pred
=
a_gl_rd
<
prob_m
;
if
(
pred
)
{
//
advance to the correct codebook, this easy because we only multiply one
column of the codebook.
if
(
pred
)
{
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto
codebook_size
=
&
codebook_a_sizes
.
x
;
while
(
a_gl_rd
>=
*
codebook_size
)
{
while
(
a_gl_rd
>=
*
codebook_size
)
{
codebook
+=
codebook_stride
;
++
codebook_size
;
}
...
...
@@ -231,17 +214,15 @@ __global__ void Code1x16Dequant(
while
(
iters
--
)
{
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
const
uint16_t
*
enc
=
reinterpret_cast
<
const
uint16_t
*>
(
&
A
[
a_gl_rd
]);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
int4
chunk
;
auto
dec
=
reinterpret_cast
<
uint32_t
*>
(
&
chunk
);
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
// actually help us; this brings > 2x speedup.
asm
volatile
(
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
// We bypass the L1 cache to avoid massive amounts of memory streaming
// that doesn't actually help us; this brings > 2x speedup.
asm
volatile
(
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
:
"=r"
(
dec
[
0
]),
"=r"
(
dec
[
1
]),
"=r"
(
dec
[
2
]),
"=r"
(
dec
[
3
])
:
"l"
((
void
*
)
&
codebook
[
enc
[
i
]])
);
:
"l"
((
void
*
)
&
codebook
[
enc
[
i
]]));
C
[
a_gl_rd
*
8
+
i
]
=
chunk
;
}
...
...
@@ -250,26 +231,23 @@ __global__ void Code1x16Dequant(
}
}
__global__
void
Code2x8Dequant
(
const
int4
*
__restrict__
A
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
const
int4
*
__restrict__
A
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at
// most 3 long, corresponds to cols.
const
int
codebook_stride
// as int4
)
{
int
a_gl_stride
=
prob_k
/
8
/
8
;
int
a_gl_rd
=
(
blockDim
.
x
/
32
)
*
blockIdx
.
x
+
(
threadIdx
.
x
/
32
);
bool
pred
=
a_gl_rd
<
prob_m
;
if
(
pred
)
{
//
advance to the correct codebook, this easy because we only multiply one
column of the codebook.
if
(
pred
)
{
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto
codebook_size
=
&
codebook_a_sizes
.
x
;
while
(
a_gl_rd
>=
*
codebook_size
)
{
while
(
a_gl_rd
>=
*
codebook_size
)
{
codebook
+=
codebook_stride
;
++
codebook_size
;
}
...
...
@@ -290,9 +268,8 @@ __global__ void Code2x8Dequant(
for
(
int
i
=
threadIdx
.
x
;
i
<
2
*
256
;
i
+=
blockDim
.
x
)
{
int4
dec
=
codebook
[
i
];
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
sh_code
[
8
*
i
+
(
j
+
lane
)
%
8
]
=
dec
;
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
sh_code
[
8
*
i
+
(
j
+
lane
)
%
8
]
=
dec
;
}
__syncthreads
();
...
...
@@ -302,12 +279,14 @@ __global__ void Code2x8Dequant(
while
(
iters
--
)
{
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
const
uint8_t
*
enc
=
reinterpret_cast
<
const
uint8_t
*>
(
&
A
[
a_gl_rd
]);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
int4
chunk
;
half2
*
a0
=
reinterpret_cast
<
half2
*>
(
&
sh_code0
[
8
*
enc
[
2
*
i
+
0
]
+
lane
]);
half2
*
a1
=
reinterpret_cast
<
half2
*>
(
&
sh_code1
[
8
*
enc
[
2
*
i
+
1
]
+
lane
]);
#pragma unroll
half2
*
a0
=
reinterpret_cast
<
half2
*>
(
&
sh_code0
[
8
*
enc
[
2
*
i
+
0
]
+
lane
]);
half2
*
a1
=
reinterpret_cast
<
half2
*>
(
&
sh_code1
[
8
*
enc
[
2
*
i
+
1
]
+
lane
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
half2
*>
(
&
chunk
)[
j
]
=
__hadd2
(
a0
[
j
],
a1
[
j
]);
C
[
a_gl_rd
*
8
+
i
]
=
chunk
;
...
...
@@ -317,22 +296,15 @@ __global__ void Code2x8Dequant(
}
}
inline
int
ceildiv
(
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
inline
int
ceildiv
(
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
const
int
THREAD_M
=
16
;
void
code1x16_matvec_cuda
(
const
void
*
__restrict__
A
,
const
void
*
__restrict__
B
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
const
int
codebook_stride
)
{
void
code1x16_matvec_cuda
(
const
void
*
__restrict__
A
,
const
void
*
__restrict__
B
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
const
int
codebook_stride
)
{
int
sms
;
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
0
);
int
waves
=
0
;
...
...
@@ -345,28 +317,16 @@ void code1x16_matvec_cuda(
int
blocks
=
ceildiv
(
prob_m
,
thread_m
);
int
threads
=
32
*
thread_m
;
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
Code1x16MatVec
<<<
blocks
,
threads
,
16
*
32
*
9
,
stream
>>>
(
(
const
int4
*
)
A
,
(
const
int4
*
)
B
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
);
Code1x16MatVec
<<<
blocks
,
threads
,
16
*
32
*
9
,
stream
>>>
(
(
const
int4
*
)
A
,
(
const
int4
*
)
B
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
);
}
void
code2x8_matvec_cuda
(
const
void
*
__restrict__
A
,
const
void
*
__restrict__
B
,
void
code2x8_matvec_cuda
(
const
void
*
__restrict__
A
,
const
void
*
__restrict__
B
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
const
int
codebook_stride
)
{
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
const
int
codebook_stride
)
{
int
sms
;
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
0
);
int
waves
=
0
;
...
...
@@ -379,29 +339,19 @@ void code2x8_matvec_cuda(
int
blocks
=
ceildiv
(
prob_m
,
thread_m
);
int
threads
=
32
*
thread_m
;
int
shared
=
16
*
(
2
*
256
*
8
+
32
*
9
);
cudaFuncSetAttribute
(
Code2x8MatVec
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shared
);
cudaFuncSetAttribute
(
Code2x8MatVec
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shared
);
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
Code2x8MatVec
<<<
blocks
,
threads
,
shared
,
stream
>>>
(
(
const
int4
*
)
A
,
(
const
int4
*
)
B
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
);
(
const
int4
*
)
A
,
(
const
int4
*
)
B
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
);
}
void
code1x16_dequant_cuda
(
const
void
*
__restrict__
A
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at most 3 long.
const
void
*
__restrict__
A
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each
// codebook, at most 3 long.
const
int
codebook_stride
// as int4.
)
{
int
sms
;
...
...
@@ -417,24 +367,20 @@ void code1x16_dequant_cuda(
int
threads
=
32
*
thread_m
;
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
Code1x16Dequant
<<<
blocks
,
threads
,
0
,
stream
>>>
(
(
const
int4
*
)
A
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at most 3 long.
(
const
int4
*
)
A
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at
// most 3 long.
codebook_stride
// as int4.
);
}
// Dequantizes the code and codebook into weights.
void
code2x8_dequant_cuda
(
const
void
*
__restrict__
A
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
const
void
*
__restrict__
A
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at
// most 3 long, corresponds to cols.
const
int
codebook_stride
// as int4
)
{
int
sms
;
...
...
@@ -451,50 +397,33 @@ void code2x8_dequant_cuda(
int
shared
=
16
*
(
2
*
256
*
8
+
32
*
9
);
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cudaFuncSetAttribute
(
Code2x8Dequant
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shared
);
cudaFuncSetAttribute
(
Code2x8Dequant
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shared
);
Code2x8Dequant
<<<
blocks
,
threads
,
shared
,
stream
>>>
(
(
const
int4
*
)
A
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
);
(
const
int4
*
)
A
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
);
}
int
codebook_stride
(
const
torch
::
Tensor
&
codebooks
)
{
int
codebook_stride
(
const
torch
::
Tensor
&
codebooks
)
{
return
codebooks
.
stride
(
0
)
*
codebooks
.
element_size
()
/
sizeof
(
int4
);
}
void
code1x16_matvec
(
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
torch
::
Tensor
&
C
,
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
torch
::
Tensor
&
C
,
const
torch
::
Tensor
&
codebook
,
const
int4
codebook_a_sizes
// cumulative sizes of A spanning each codebook, at most 3 long.
const
int4
codebook_a_sizes
// cumulative sizes of A spanning each
// codebook, at most 3 long.
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
A
));
int
prob_m
=
C
.
size
(
0
);
int
prob_k
=
B
.
size
(
0
);
code1x16_matvec_cuda
(
A
.
data_ptr
(),
B
.
data_ptr
(),
C
.
data_ptr
(),
codebook
.
data_ptr
(),
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
(
codebook
)
);
code1x16_matvec_cuda
(
A
.
data_ptr
(),
B
.
data_ptr
(),
C
.
data_ptr
(),
codebook
.
data_ptr
(),
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
(
codebook
));
}
torch
::
Tensor
code1x16_matmat
(
const
torch
::
Tensor
&
input
,
torch
::
Tensor
code1x16_matmat
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
...
...
@@ -503,22 +432,15 @@ torch::Tensor code1x16_matmat(
auto
input_sizes
=
input
.
sizes
();
auto
out_features
=
codes
.
size
(
0
)
*
codebooks
.
size
(
2
);
auto
flat_input
=
input
.
reshape
({
-
1
,
input
.
size
(
-
1
)});
auto
flat_output
=
torch
::
empty
({
flat_input
.
size
(
0
),
out_features
},
torch
::
TensorOptions
()
.
dtype
(
input
.
dtype
())
.
device
(
input
.
device
())
);
auto
flat_output
=
torch
::
empty
(
{
flat_input
.
size
(
0
),
out_features
},
torch
::
TensorOptions
().
dtype
(
input
.
dtype
()).
device
(
input
.
device
()));
for
(
int
i
=
0
;
i
<
flat_input
.
size
(
0
);
++
i
)
{
auto
input_vec
=
flat_input
.
index
({
i
});
auto
output_vec
=
flat_output
.
index
({
i
});
code1x16_matvec
(
codes
.
squeeze
(
2
),
input_vec
,
output_vec
,
codebooks
,
codebook_a_sizes
);
code1x16_matvec
(
codes
.
squeeze
(
2
),
input_vec
,
output_vec
,
codebooks
,
codebook_a_sizes
);
}
flat_output
*=
scales
.
flatten
().
unsqueeze
(
0
);
...
...
@@ -533,55 +455,35 @@ torch::Tensor code1x16_matmat(
return
output
;
}
void
code2x8_matvec
(
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
torch
::
Tensor
&
C
,
const
torch
::
Tensor
&
codebook
,
const
int4
codebook_a_sizes
)
{
void
code2x8_matvec
(
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
torch
::
Tensor
&
C
,
const
torch
::
Tensor
&
codebook
,
const
int4
codebook_a_sizes
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
A
));
int
prob_m
=
C
.
size
(
0
);
int
prob_k
=
B
.
size
(
0
);
code2x8_matvec_cuda
(
A
.
data_ptr
(),
B
.
data_ptr
(),
C
.
data_ptr
(),
codebook
.
data_ptr
(),
prob_m
,
prob_k
,
codebook_a_sizes
,
2
*
codebook_stride
(
codebook
)
);
code2x8_matvec_cuda
(
A
.
data_ptr
(),
B
.
data_ptr
(),
C
.
data_ptr
(),
codebook
.
data_ptr
(),
prob_m
,
prob_k
,
codebook_a_sizes
,
2
*
codebook_stride
(
codebook
));
}
torch
::
Tensor
code2x8_matmat
(
const
torch
::
Tensor
&
input
,
torch
::
Tensor
code2x8_matmat
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
int4
codebook_a_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
)
{
const
std
::
optional
<
torch
::
Tensor
>&
bias
)
{
auto
input_sizes
=
input
.
sizes
();
auto
out_features
=
codes
.
size
(
0
)
*
codebooks
.
size
(
2
);
auto
flat_input
=
input
.
reshape
({
-
1
,
input
.
size
(
-
1
)});
auto
flat_output
=
torch
::
empty
({
flat_input
.
size
(
0
),
out_features
},
torch
::
TensorOptions
()
.
dtype
(
input
.
dtype
())
.
device
(
input
.
device
())
);
auto
flat_output
=
torch
::
empty
(
{
flat_input
.
size
(
0
),
out_features
},
torch
::
TensorOptions
().
dtype
(
input
.
dtype
()).
device
(
input
.
device
()));
for
(
int
i
=
0
;
i
<
flat_input
.
size
(
0
);
++
i
)
{
auto
input_vec
=
flat_input
.
index
({
i
});
auto
output_vec
=
flat_output
.
index
({
i
});
code2x8_matvec
(
codes
.
squeeze
(
2
),
input_vec
,
output_vec
,
codebooks
,
codebook_a_sizes
);
code2x8_matvec
(
codes
.
squeeze
(
2
),
input_vec
,
output_vec
,
codebooks
,
codebook_a_sizes
);
}
flat_output
*=
scales
.
flatten
().
unsqueeze
(
0
);
if
(
bias
.
has_value
())
{
...
...
@@ -596,22 +498,19 @@ torch::Tensor code2x8_matmat(
}
// Accumulate the partition sizes.
int4
accumulate_sizes
(
const
torch
::
Tensor
&
codebook_partition_sizes
)
{
int4
accumulate_sizes
(
const
torch
::
Tensor
&
codebook_partition_sizes
)
{
int4
cumulative_sizes
;
auto
cumulative_size
=
&
cumulative_sizes
.
x
;
int
i
=
0
;
int
last
=
0
;
assert
(
codebook_partition_sizes
.
size
(
0
)
<=
4
);
for
(;
i
<
codebook_partition_sizes
.
size
(
0
);
++
i
,
++
cumulative_size
)
{
for
(;
i
<
codebook_partition_sizes
.
size
(
0
);
++
i
,
++
cumulative_size
)
{
*
cumulative_size
=
codebook_partition_sizes
[
i
].
item
<
int
>
()
+
last
;
last
=
*
cumulative_size
;
}
// fill in the rest with unreachable.
for
(;
i
<
4
;
++
i
,
++
cumulative_size
)
{
*
cumulative_size
=
last
*
10
;
for
(;
i
<
4
;
++
i
,
++
cumulative_size
)
{
*
cumulative_size
=
last
*
10
;
}
return
cumulative_sizes
;
}
...
...
@@ -619,41 +518,36 @@ int4 accumulate_sizes(const torch::Tensor& codebook_partition_sizes)
}
// namespace aqlm
}
// namespace vllm
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
torch
::
Tensor
&
codebook_partition_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
)
{
int4
cumulative_sizes
=
vllm
::
aqlm
::
accumulate_sizes
(
codebook_partition_sizes
);
const
std
::
optional
<
torch
::
Tensor
>&
bias
)
{
int4
cumulative_sizes
=
vllm
::
aqlm
::
accumulate_sizes
(
codebook_partition_sizes
);
int
const
nbooks
=
codebooks
.
size
(
0
)
/
codebook_partition_sizes
.
size
(
0
);
int
const
entries
=
codebooks
.
size
(
1
);
if
(
nbooks
==
1
&&
entries
==
(
1
<<
16
))
{
return
vllm
::
aqlm
::
code1x16_matmat
(
input
,
codes
,
codebooks
,
scales
,
cumulative_sizes
,
bias
);
if
(
nbooks
==
1
&&
entries
==
(
1
<<
16
))
{
return
vllm
::
aqlm
::
code1x16_matmat
(
input
,
codes
,
codebooks
,
scales
,
cumulative_sizes
,
bias
);
}
if
(
nbooks
==
2
&&
entries
==
(
1
<<
8
))
{
return
vllm
::
aqlm
::
code2x8_matmat
(
input
,
codes
,
codebooks
,
scales
,
cumulative_sizes
,
bias
);
if
(
nbooks
==
2
&&
entries
==
(
1
<<
8
))
{
return
vllm
::
aqlm
::
code2x8_matmat
(
input
,
codes
,
codebooks
,
scales
,
cumulative_sizes
,
bias
);
}
TORCH_CHECK
(
false
,
"AQLM with "
,
nbooks
,
" codebooks and "
,
entries
,
" entries is not currently supported."
)
TORCH_CHECK
(
false
,
"AQLM with "
,
nbooks
,
" codebooks and "
,
entries
,
" entries is not currently supported."
)
return
{};
}
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codebook_partition_sizes
)
{
int4
cumulative_sizes
=
vllm
::
aqlm
::
accumulate_sizes
(
codebook_partition_sizes
);
const
torch
::
Tensor
&
codebook_partition_sizes
)
{
int4
cumulative_sizes
=
vllm
::
aqlm
::
accumulate_sizes
(
codebook_partition_sizes
);
int
const
nbooks
=
codebooks
.
size
(
0
)
/
codebook_partition_sizes
.
size
(
0
);
int
const
entries
=
codebooks
.
size
(
1
);
...
...
@@ -670,43 +564,35 @@ torch::Tensor aqlm_dequant(
auto
weights
=
torch
::
empty
({
out_features
,
in_features
},
torch
::
TensorOptions
()
.
dtype
(
codebooks
.
dtype
())
.
device
(
codebooks
.
device
())
);
.
device
(
codebooks
.
device
()));
if
(
nbooks
==
1
&&
entries
==
(
1
<<
16
))
{
vllm
::
aqlm
::
code1x16_dequant_cuda
(
codes
.
data_ptr
(),
weights
.
data_ptr
(),
codebooks
.
data_ptr
(),
out_features
,
in_features
,
cumulative_sizes
,
if
(
nbooks
==
1
&&
entries
==
(
1
<<
16
))
{
vllm
::
aqlm
::
code1x16_dequant_cuda
(
codes
.
data_ptr
(),
weights
.
data_ptr
(),
codebooks
.
data_ptr
(),
out_features
,
in_features
,
cumulative_sizes
,
vllm
::
aqlm
::
codebook_stride
(
codebooks
));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.)
// weights *= scales.index({"...", 0, 0});
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower
// and not consistent with gemv implementation.) weights *=
// scales.index({"...", 0, 0});
return
weights
;
}
if
(
nbooks
==
2
&&
entries
==
(
1
<<
8
))
{
vllm
::
aqlm
::
code2x8_dequant_cuda
(
codes
.
data_ptr
(),
weights
.
data_ptr
(),
codebooks
.
data_ptr
(),
out_features
,
in_features
,
cumulative_sizes
,
if
(
nbooks
==
2
&&
entries
==
(
1
<<
8
))
{
vllm
::
aqlm
::
code2x8_dequant_cuda
(
codes
.
data_ptr
(),
weights
.
data_ptr
(),
codebooks
.
data_ptr
(),
out_features
,
in_features
,
cumulative_sizes
,
vllm
::
aqlm
::
codebook_stride
(
codebooks
));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation)
// weights *= scales.index({"...", 0, 0});
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower
// and not consistent with gemv implementation) weights *=
// scales.index({"...", 0, 0});
return
weights
;
}
TORCH_CHECK
(
false
,
"AQLM with "
,
nbooks
,
" codebooks and "
,
entries
,
" entries is not currently supported."
)
TORCH_CHECK
(
false
,
"AQLM with "
,
nbooks
,
" codebooks and "
,
entries
,
" entries is not currently supported."
)
return
{};
}
csrc/quantization/awq/dequantize.cuh
View file @
5f6d10c1
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
title={AWQ: Activation-aware Weight Quantization for LLM Compression and
Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
}
*/
...
...
@@ -14,8 +14,7 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
namespace
vllm
{
namespace
awq
{
__device__
uint4
dequantize_s4_to_fp16x2
(
uint32_t
const
&
source
)
{
__device__
uint4
dequantize_s4_to_fp16x2
(
uint32_t
const
&
source
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
assert
(
false
);
#else
...
...
@@ -30,33 +29,40 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
static
constexpr
uint32_t
TOP_MASK
=
0x00f000f0
;
static
constexpr
uint32_t
I4s_TO_F16s_MAGIC_NUM
=
0x64006400
;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
// Note that the entire sequence only requires 1 shift instruction. This is
// thanks to the register packing format and the fact that we force our
// integers to be unsigned, and account for this in the fp16 subtractions. In
// addition, I exploit the fact that sub and fma have the same throughput in
// order to convert elt_23 and elt_67 to fp16 without having to shift them to
// the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
dependency if we issue
// immediately before required.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
//
dependency if we issue
immediately before required.
const
uint32_t
top_i4s
=
i4s
>>
8
;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
0
])
:
"r"
(
i4s
),
"n"
(
BOTTOM_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
:
"r"
(
i4s
),
"n"
(
BOTTOM_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
1
])
:
"r"
(
i4s
),
"n"
(
TOP_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
:
"r"
(
i4s
),
"n"
(
TOP_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
2
])
:
"r"
(
top_i4s
),
"n"
(
BOTTOM_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
:
"r"
(
top_i4s
),
"n"
(
BOTTOM_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
3
])
:
"r"
(
top_i4s
),
"n"
(
TOP_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
:
"r"
(
top_i4s
),
"n"
(
TOP_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// I use inline PTX below because I am not sure if the compiler will emit
// float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
...
...
@@ -71,13 +77,21 @@ __device__ uint4 dequantize_s4_to_fp16x2(uint32_t const& source)
// Finally, we construct the output numbers.
// Convert elt_01
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
h
[
0
])
:
"r"
(
h
[
0
]),
"r"
(
FP16_TOP_MAGIC_NUM
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
h
[
0
])
:
"r"
(
h
[
0
]),
"r"
(
FP16_TOP_MAGIC_NUM
));
// Convert elt_23
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
h
[
1
])
:
"r"
(
h
[
1
]),
"r"
(
ONE_SIXTEENTH
),
"r"
(
NEG_64
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
h
[
1
])
:
"r"
(
h
[
1
]),
"r"
(
ONE_SIXTEENTH
),
"r"
(
NEG_64
));
// Convert elt_45
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
h
[
2
])
:
"r"
(
h
[
2
]),
"r"
(
FP16_TOP_MAGIC_NUM
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
h
[
2
])
:
"r"
(
h
[
2
]),
"r"
(
FP16_TOP_MAGIC_NUM
));
// Convert elt_67
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
h
[
3
])
:
"r"
(
h
[
3
]),
"r"
(
ONE_SIXTEENTH
),
"r"
(
NEG_64
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
h
[
3
])
:
"r"
(
h
[
3
]),
"r"
(
ONE_SIXTEENTH
),
"r"
(
NEG_64
));
return
result
;
#endif
...
...
csrc/quantization/awq/gemm_kernels.cu
View file @
5f6d10c1
/*
Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
title={AWQ: Activation-aware Weight Quantization for LLM Compression and
Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
}
*/
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
...
...
@@ -20,26 +18,20 @@ namespace vllm {
namespace
awq
{
// Pack two half values.
static
inline
__device__
__host__
unsigned
__pack_half2
(
const
half
x
,
const
half
y
)
{
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
static
inline
__device__
__host__
unsigned
__pack_half2
(
const
half
x
,
const
half
y
)
{
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
return
(
v1
<<
16
)
|
v0
;
}
template
<
int
N
>
__global__
void
__launch_bounds__
(
64
)
gemm_forward_4bit_cuda_m16nXk32
(
int
G
,
int
split_k_iters
,
half
*
__restrict__
A
,
int
*
__restrict__
B
,
template
<
int
N
>
__global__
void
__launch_bounds__
(
64
)
gemm_forward_4bit_cuda_m16nXk32
(
int
G
,
int
split_k_iters
,
half
*
__restrict__
A
,
int
*
__restrict__
B
,
half
*
__restrict__
scaling_factors
,
int
*
__restrict__
zeros
,
int
M
,
int
IC
,
int
OC
,
half
*
__restrict__
C
)
{
int
*
__restrict__
zeros
,
int
M
,
int
IC
,
int
OC
,
half
*
__restrict__
C
)
{
// Only support matrix n = 64 or 128
assert
(
N
==
64
||
N
==
128
);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
...
...
@@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
static
constexpr
int
row_stride
=
2
*
32
*
8
/
N
;
bool
ld_zero_flag
=
(
threadIdx
.
y
*
32
+
threadIdx
.
x
)
*
8
<
N
;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool
ld_A_flag
=
(
blockIdx_y
/
j_factors1
*
16
+
threadIdx
.
y
*
row_stride_warp
+
threadIdx
.
x
*
8
/
32
)
<
M
;
// threadIdx.y is warp_id
bool
ld_A_flag
=
(
blockIdx_y
/
j_factors1
*
16
+
threadIdx
.
y
*
row_stride_warp
+
threadIdx
.
x
*
8
/
32
)
<
M
;
// threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half
*
A_ptr
=
A
+
(((
int
)
blockIdx_y
)
/
j_factors1
*
16
+
(((
int
)
threadIdx
.
y
)
*
row_stride_warp
)
+
((
int
)
threadIdx
.
x
)
/
(
32
/
8
))
*
IC
+
(((
int
)
threadIdx
.
x
)
%
(
32
/
8
))
*
8
;
int
*
B_ptr
=
B
+
((
int
)
threadIdx
.
y
)
*
(
OC
/
8
)
*
(
256
/
N
)
+
(((
int
)
threadIdx
.
x
)
/
(
N
/
8
))
*
(
OC
/
8
)
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
(
N
/
8
)
+
(((
int
)
threadIdx
.
x
)
%
(
N
/
8
))
*
1
;
// Why * 1 in the above line?
half
*
A_shared_ptr
=
A_shared
+
((
int
)
threadIdx
.
y
)
*
row_stride_warp
*
(
32
+
8
)
+
(((
int
)
threadIdx
.
x
)
/
(
32
/
8
))
*
(
32
+
8
)
+
(((
int
)
threadIdx
.
x
)
%
(
32
/
8
)
)
*
8
;
half
*
B_shared_ptr
=
B_shared
+
((
int
)
threadIdx
.
y
)
*
(
row_stride
/
2
)
*
(
N
+
8
)
+
(((
int
)
threadIdx
.
x
)
/
(
N
/
8
))
*
(
N
+
8
)
+
(((
int
)
threadIdx
.
x
)
%
(
N
/
8
))
*
8
;
int
*
zeros_ptr
=
zeros
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
(
N
/
8
)
+
((
int
)
threadIdx
.
x
)
%
(
N
/
8
);
half
*
scaling_factors_ptr
=
scaling_factors
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
N
+
(((
int
)
threadIdx
.
x
)
%
(
N
/
8
))
*
8
;
half
*
C_ptr
=
C
+
static_cast
<
long
long
>
(
blockIdx_z
)
*
M
*
OC
// blockIdz.x -> split_k dim
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
N
+
((
int
)
threadIdx
.
y
)
*
(
N
/
2
)
+
(((
int
)
threadIdx
.
x
)
%
4
)
*
2
;
half
*
A_ptr
=
A
+
(((
int
)
blockIdx_y
)
/
j_factors1
*
16
+
(((
int
)
threadIdx
.
y
)
*
row_stride_warp
)
+
((
int
)
threadIdx
.
x
)
/
(
32
/
8
))
*
IC
+
(((
int
)
threadIdx
.
x
)
%
(
32
/
8
))
*
8
;
int
*
B_ptr
=
B
+
((
int
)
threadIdx
.
y
)
*
(
OC
/
8
)
*
(
256
/
N
)
+
(((
int
)
threadIdx
.
x
)
/
(
N
/
8
))
*
(
OC
/
8
)
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
(
N
/
8
)
+
(((
int
)
threadIdx
.
x
)
%
(
N
/
8
))
*
1
;
// Why * 1 in the above line?
half
*
A_shared_ptr
=
A_shared
+
((
int
)
threadIdx
.
y
)
*
row_stride_warp
*
(
32
+
8
)
+
(((
int
)
threadIdx
.
x
)
/
(
32
/
8
))
*
(
32
+
8
)
+
(((
int
)
threadIdx
.
x
)
%
(
32
/
8
))
*
8
;
half
*
B_shared_ptr
=
B_shared
+
((
int
)
threadIdx
.
y
)
*
(
row_stride
/
2
)
*
(
N
+
8
)
+
(((
int
)
threadIdx
.
x
)
/
(
N
/
8
))
*
(
N
+
8
)
+
(((
int
)
threadIdx
.
x
)
%
(
N
/
8
))
*
8
;
int
*
zeros_ptr
=
zeros
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
(
N
/
8
)
+
((
int
)
threadIdx
.
x
)
%
(
N
/
8
);
half
*
scaling_factors_ptr
=
scaling_factors
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
N
+
(((
int
)
threadIdx
.
x
)
%
(
N
/
8
))
*
8
;
half
*
C_ptr
=
C
+
static_cast
<
long
long
>
(
blockIdx_z
)
*
M
*
OC
// blockIdz.x -> split_k dim
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
N
+
((
int
)
threadIdx
.
y
)
*
(
N
/
2
)
+
(((
int
)
threadIdx
.
x
)
%
4
)
*
2
;
// preload s.f. and zeros
int
k_bound
=
(
IC
/
32
+
split_k_iters
-
1
)
/
split_k_iters
;
...
...
@@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
int
k_0_0
=
_k_0_0
*
split_k_iters
+
blockIdx_z
;
__syncthreads
();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if
(
ld_A_flag
)
{
if
(
ld_A_flag
)
{
*
(
uint4
*
)(
A_shared_ptr
)
=
*
(
uint4
*
)(
A_ptr
+
(
k_0_0
*
32
));
}
else
{
}
else
{
*
(
uint4
*
)(
A_shared_ptr
)
=
make_uint4
(
0
,
0
,
0
,
0
);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t
zeros_loaded
=
*
(
uint32_t
*
)(
zeros_ptr
+
k_0_0
*
32
/
G
*
(
OC
/
8
));
uint4
B_loaded_zero
=
dequantize_s4_to_fp16x2
(
zeros_loaded
);
uint4
B_loaded_scale
=
*
(
uint4
*
)(
scaling_factors_ptr
+
k_0_0
*
32
/
G
*
(
OC
));
uint4
B_loaded_scale
=
*
(
uint4
*
)(
scaling_factors_ptr
+
k_0_0
*
32
/
G
*
(
OC
));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 &&
threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x,
B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x,
B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int
*
B_ptr_local
=
B_ptr
+
k_0_0
*
32
*
(
OC
/
8
);
for
(
int
ax0_ax1_fused_0
=
0
;
ax0_ax1_fused_0
<
N
/
16
;
++
ax0_ax1_fused_0
)
{
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t
B_loaded
=
*
(
uint32_t
*
)(
B_ptr_local
+
ax0_ax1_fused_0
*
row_stride
*
(
OC
/
8
));
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus
// zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) *
// 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15)
// * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 *
// 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) *
// 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) *
// 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t
B_loaded
=
*
(
uint32_t
*
)(
B_ptr_local
+
ax0_ax1_fused_0
*
row_stride
*
(
OC
/
8
));
uint4
B_loaded_fp16
=
dequantize_s4_to_fp16x2
(
B_loaded
);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
// 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
// % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_zero
.
x
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_scale
.
x
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_zero
.
y
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_scale
.
y
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_zero
.
z
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_scale
.
z
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_zero
.
w
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_scale
.
w
),
"r"
(
ZERO
));
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
// q * scale - zero * scale.
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_zero
.
x
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_scale
.
x
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_zero
.
y
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_scale
.
y
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_zero
.
z
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_scale
.
z
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_zero
.
w
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_scale
.
w
),
"r"
(
ZERO
));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 ==
0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n",
B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*
(
uint4
*
)(
B_shared_ptr
+
ax0_ax1_fused_0
*
row_stride
*
(
N
+
8
))
=
B_loaded_fp16
;
*
(
uint4
*
)(
B_shared_ptr
+
ax0_ax1_fused_0
*
row_stride
*
(
N
+
8
))
=
B_loaded_fp16
;
}
__syncthreads
();
...
...
@@ -173,112 +194,179 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
{
unsigned
int
addr
;
__asm__
__volatile__
(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }
\n
"
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
"addr; }
\n
"
:
"=r"
(
addr
)
:
"l"
((
void
*
)((
&
(
A_shared
[(
k_0_1
*
16
)]))
+
(((((
int
)
threadIdx
.
x
)
&
15
)
*
40
)
+
((((
int
)
threadIdx
.
x
)
>>
4
)
*
8
))))
);
:
"l"
((
void
*
)((
&
(
A_shared
[(
k_0_1
*
16
)]))
+
(((((
int
)
threadIdx
.
x
)
&
15
)
*
40
)
+
((((
int
)
threadIdx
.
x
)
>>
4
)
*
8
)))));
__asm__
__volatile__
(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];
\n
"
:
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
])
:
"r"
(
addr
)
);
:
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
])
:
"r"
(
addr
));
}
for
(
int
ax1_0
=
0
;
ax1_0
<
N
/
32
;
++
ax1_0
)
{
{
unsigned
int
addr
;
__asm__
__volatile__
(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }
\n
"
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
"addr; }
\n
"
:
"=r"
(
addr
)
:
"l"
((
void
*
)((
&
(
B_shared
[(((
k_0_1
*
(
N
*
16
+
128
))
+
(((
int
)
threadIdx
.
y
)
*
(
N
/
2
)))
+
(
ax1_0
*
16
))]))
+
(((((
int
)
threadIdx
.
x
)
&
15
)
*
(
N
+
8
))
+
((((
int
)
threadIdx
.
x
)
>>
4
)
*
8
))))
);
:
"l"
((
void
*
)((
&
(
B_shared
[(((
k_0_1
*
(
N
*
16
+
128
))
+
(((
int
)
threadIdx
.
y
)
*
(
N
/
2
)))
+
(
ax1_0
*
16
))]))
+
(((((
int
)
threadIdx
.
x
)
&
15
)
*
(
N
+
8
))
+
((((
int
)
threadIdx
.
x
)
>>
4
)
*
8
)))));
__asm__
__volatile__
(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];
\n
"
:
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
0
]),
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
1
]),
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
2
]),
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
3
])
:
"r"
(
addr
)
);
:
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
0
]),
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
1
]),
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
2
]),
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
3
])
:
"r"
(
addr
));
}
}
for
(
int
j_0_4
=
0
;
j_0_4
<
N
/
32
;
++
j_0_4
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
{
__asm__
__volatile__
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
]));
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
]));
}
{
__asm__
__volatile__
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
]));
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
]));
}
{
__asm__
__volatile__
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
]));
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
]));
}
{
__asm__
__volatile__
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
]));
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
]));
}
#else
#else
{
__asm__
__volatile__
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
0
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
]));
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
"%13};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
0
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
]));
}
{
__asm__
__volatile__
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
]));
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
"%13};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
]));
}
#endif
#endif
}
}
}
// TODO: Shang: Hoist loop invariance.
// TODO: Shang: Hoist loop invariance.
for
(
int
ax1_0_1
=
0
;
ax1_0_1
<
4
;
++
ax1_0_1
)
{
for
(
int
local_id
=
0
;
local_id
<
8
;
++
local_id
)
{
int
row_offset
=
(((
int
)
blockIdx_y
)
/
j_factors1
)
*
16
+
((
int
)
threadIdx
.
x
)
/
4
+
(
local_id
%
4
)
/
2
*
8
;
if
(
row_offset
<
M
)
{
*
(
C_ptr
+
ax1_0_1
*
16
+
row_offset
*
OC
+
(
local_id
/
4
)
*
8
+
local_id
%
2
)
=
__float2half
(
C_warp
[(
ax1_0_1
*
8
)
+
local_id
]);
int
row_offset
=
(((
int
)
blockIdx_y
)
/
j_factors1
)
*
16
+
((
int
)
threadIdx
.
x
)
/
4
+
(
local_id
%
4
)
/
2
*
8
;
if
(
row_offset
<
M
)
{
*
(
C_ptr
+
ax1_0_1
*
16
+
row_offset
*
OC
+
(
local_id
/
4
)
*
8
+
local_id
%
2
)
=
__float2half
(
C_warp
[(
ax1_0_1
*
8
)
+
local_id
]);
}
}
}
#endif
}
__global__
void
__launch_bounds__
(
64
)
dequantize_weights
(
int
*
__restrict__
B
,
half
*
__restrict__
scaling_factors
,
int
*
__restrict__
zeros
,
half
*
__restrict__
C
,
int
G
)
{
__global__
void
__launch_bounds__
(
64
)
dequantize_weights
(
int
*
__restrict__
B
,
half
*
__restrict__
scaling_factors
,
int
*
__restrict__
zeros
,
half
*
__restrict__
C
,
int
G
)
{
int
j_factors1
=
4
;
int
row_stride2
=
4
;
int
split_k_iters
=
1
;
...
...
@@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights(
uint32_t
B_loaded
=
*
(
uint32_t
*
)
B_ptr2
;
uint4
B_loaded_fp16
=
dequantize_s4_to_fp16x2
(
B_loaded
);
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_zero
.
x
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_scale
.
x
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_zero
.
y
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_scale
.
y
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_zero
.
z
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_scale
.
z
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_zero
.
w
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_scale
.
w
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_zero
.
x
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_scale
.
x
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_zero
.
y
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_scale
.
y
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_zero
.
z
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_scale
.
z
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_zero
.
w
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_scale
.
w
),
"r"
(
ZERO
));
*
(
uint4
*
)
B_shared_ptr2
=
B_loaded_fp16
;
...
...
@@ -329,14 +433,10 @@ __global__ void __launch_bounds__(64) dequantize_weights(
}
// namespace awq
}
// namespace vllm
torch
::
Tensor
awq_dequantize
(
torch
::
Tensor
_kernel
,
torch
::
Tensor
awq_dequantize
(
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
,
int
thx
,
int
thy
)
{
torch
::
Tensor
_zeros
,
int
split_k_iters
,
int
thx
,
int
thy
)
{
int
in_c
=
_kernel
.
size
(
0
);
int
qout_c
=
_kernel
.
size
(
1
);
int
out_c
=
qout_c
*
8
;
...
...
@@ -347,13 +447,13 @@ torch::Tensor awq_dequantize(
int
x_blocks
=
1
;
int
y_blocks
=
1
;
if
(
thx
==
0
)
{
if
(
thx
==
0
)
{
x_thread
=
qout_c
;
}
if
(
thy
==
0
)
{
if
(
thy
==
0
)
{
y_thread
=
in_c
;
}
if
(
thx
==
0
&&
thy
==
0
)
{
if
(
thx
==
0
&&
thy
==
0
)
{
x_thread
=
8
;
y_thread
=
8
;
x_blocks
=
(
int
)(
qout_c
/
8
);
...
...
@@ -362,12 +462,15 @@ torch::Tensor awq_dequantize(
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
_scaling_factors
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
_scaling_factors
.
dtype
()).
device
(
_scaling_factors
.
device
());
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
_scaling_factors
.
dtype
())
.
device
(
_scaling_factors
.
device
());
at
::
Tensor
_de_kernel
=
torch
::
empty
({
in_c
,
out_c
},
options
);
auto
kernel
=
reinterpret_cast
<
int
*>
(
_kernel
.
data_ptr
<
int
>
());
auto
de_kernel
=
reinterpret_cast
<
half
*>
(
_de_kernel
.
data_ptr
<
at
::
Half
>
());
auto
scaling_factors
=
reinterpret_cast
<
half
*>
(
_scaling_factors
.
data_ptr
<
at
::
Half
>
());
auto
scaling_factors
=
reinterpret_cast
<
half
*>
(
_scaling_factors
.
data_ptr
<
at
::
Half
>
());
auto
zeros
=
reinterpret_cast
<
int
*>
(
_zeros
.
data_ptr
<
int
>
());
dim3
num_blocks
(
x_blocks
,
y_blocks
);
...
...
@@ -386,26 +489,26 @@ torch::Tensor awq_dequantize(
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
)
{
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
)
{
int
num_in_feats
=
_in_feats
.
size
(
0
);
int
num_in_channels
=
_in_feats
.
size
(
1
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
_in_feats
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
_in_feats
.
dtype
()).
device
(
_in_feats
.
device
());
at
::
Tensor
_out_feats
=
torch
::
empty
({
split_k_iters
,
num_in_feats
,
_kernel
.
size
(
1
)
*
8
},
options
);
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
_in_feats
.
dtype
())
.
device
(
_in_feats
.
device
());
at
::
Tensor
_out_feats
=
torch
::
empty
({
split_k_iters
,
num_in_feats
,
_kernel
.
size
(
1
)
*
8
},
options
);
int
num_out_feats
=
_out_feats
.
size
(
-
2
);
int
num_out_channels
=
_out_feats
.
size
(
-
1
);
auto
in_feats
=
reinterpret_cast
<
half
*>
(
_in_feats
.
data_ptr
<
at
::
Half
>
());
auto
kernel
=
reinterpret_cast
<
int
*>
(
_kernel
.
data_ptr
<
int
>
());
auto
out_feats
=
reinterpret_cast
<
half
*>
(
_out_feats
.
data_ptr
<
at
::
Half
>
());
auto
scaling_factors
=
reinterpret_cast
<
half
*>
(
_scaling_factors
.
data_ptr
<
at
::
Half
>
());
auto
scaling_factors
=
reinterpret_cast
<
half
*>
(
_scaling_factors
.
data_ptr
<
at
::
Half
>
());
auto
zeros
=
reinterpret_cast
<
int
*>
(
_zeros
.
data_ptr
<
int
>
());
int
group_size
=
num_in_channels
/
_scaling_factors
.
size
(
0
);
...
...
@@ -419,28 +522,28 @@ torch::Tensor awq_gemm(
throw
std
::
invalid_argument
(
"OC is not multiple of Group size"
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
num_out_channels
%
128
==
0
)
{
if
(
num_out_channels
%
128
==
0
)
{
int
j_factors1
=
num_out_channels
/
128
/
1
;
dim3
num_blocks
((
num_out_feats
+
16
-
1
)
/
16
*
j_factors1
*
split_k_iters
);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3
threads_per_block
(
32
,
2
);
vllm
::
awq
::
gemm_forward_4bit_cuda_m16nXk32
<
128
><<<
num_blocks
,
threads_per_block
,
0
,
stream
>>>
(
group_size
,
split_k_iters
,
in_feats
,
kernel
,
scaling_factors
,
zeros
,
num_in_feats
,
num_in_channels
,
num_out_channels
,
out_feats
);
}
else
if
(
num_out_channels
%
64
==
0
)
{
vllm
::
awq
::
gemm_forward_4bit_cuda_m16nXk32
<
128
>
<<<
num_blocks
,
threads_per_block
,
0
,
stream
>>>
(
group_size
,
split_k_iters
,
in_feats
,
kernel
,
scaling_factors
,
zeros
,
num_in_feats
,
num_in_channels
,
num_out_channels
,
out_feats
);
}
else
if
(
num_out_channels
%
64
==
0
)
{
int
j_factors1
=
num_out_channels
/
64
/
1
;
dim3
num_blocks
(
1
*
(
num_out_feats
+
16
-
1
)
/
16
*
j_factors1
*
split_k_iters
);
dim3
num_blocks
(
1
*
(
num_out_feats
+
16
-
1
)
/
16
*
j_factors1
*
split_k_iters
);
// threadIdx.x: 32
// threadIdx.y: i_factors[2] * j_factors[2]
dim3
threads_per_block
(
32
,
2
);
vllm
::
awq
::
gemm_forward_4bit_cuda_m16nXk32
<
64
><<<
num_blocks
,
threads_per_block
,
0
,
stream
>>>
(
group_size
,
split_k_iters
,
in_feats
,
kernel
,
scaling_factors
,
zeros
,
num_in_feats
,
num_in_channels
,
num_out_channels
,
out_feats
);
vllm
::
awq
::
gemm_forward_4bit_cuda_m16nXk32
<
64
>
<<<
num_blocks
,
threads_per_block
,
0
,
stream
>>>
(
group_size
,
split_k_iters
,
in_feats
,
kernel
,
scaling_factors
,
zeros
,
num_in_feats
,
num_in_channels
,
num_out_channels
,
out_feats
);
}
return
_out_feats
.
sum
(
0
);
}
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
View file @
5f6d10c1
...
...
@@ -117,10 +117,10 @@ struct cutlass_2x_gemm {
};
template
<
typename
Gemm
>
void
cutlass_scaled_mm_dq_dispatcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_scaled_mm_dq_dispatcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
...
...
@@ -136,9 +136,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
using
StrideC
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
const
*>
(
b
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
const
*>
(
b
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
auto
a_scales_ptr
=
a_scales
.
data_ptr
<
float
>
();
auto
b_scales_ptr
=
b_scales
.
data_ptr
<
float
>
();
...
...
@@ -196,10 +196,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
}
// namespace
void
cutlass_scaled_mm_dq_sm75
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_scaled_mm_dq_sm75
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
...
...
@@ -223,10 +223,10 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a,
}
}
void
cutlass_scaled_mm_dq_sm80
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_scaled_mm_dq_sm80
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
...
...
@@ -250,10 +250,10 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a,
}
}
void
cutlass_scaled_mm_dq_sm89
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_scaled_mm_dq_sm89
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
View file @
5f6d10c1
...
...
@@ -120,10 +120,10 @@ struct cutlass_3x_gemm {
};
template
<
typename
Gemm
>
void
cutlass_scaled_mm_dq_dispatcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_scaled_mm_dq_dispatcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
...
...
@@ -146,12 +146,12 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
typename
GemmKernel
::
ProblemShape
prob_shape
{
m
,
n
,
k
,
1
};
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
...
...
@@ -183,10 +183,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
}
}
// namespace
void
cutlass_scaled_mm_dq_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_scaled_mm_dq_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
View file @
5f6d10c1
...
...
@@ -2,29 +2,29 @@
#include <cuda_runtime.h>
#include <torch/extension.h>
void
cutlass_scaled_mm_dq_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_scaled_mm_dq
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
int32_t
major_capability
;
int32_t
minor_capability
;
cudaDeviceGetAttribute
(
&
major_capability
,
cudaDevAttrComputeCapabilityMajor
,
...
...
@@ -43,7 +43,8 @@ void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a,
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
c
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
b
.
stride
(
0
)
==
1
);
// Column-major
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
&&
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
&&
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
...
...
csrc/quantization/fp8/amd/hip_float8.h
View file @
5f6d10c1
#pragma once
#ifdef __HIPCC__
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#else
#include <type_traits>
#include <stdint.h>
#include <math.h>
#include <iostream>
#include <type_traits>
#include <stdint.h>
#include <math.h>
#include <iostream>
#endif
#include "hip_float8_impl.h"
struct
alignas
(
1
)
hip_fp8
{
struct
from_bits_t
{
};
HIP_FP8_HOST_DEVICE
static
constexpr
from_bits_t
from_bits
()
{
return
from_bits_t
();
}
struct
alignas
(
1
)
hip_fp8
{
struct
from_bits_t
{};
HIP_FP8_HOST_DEVICE
static
constexpr
from_bits_t
from_bits
()
{
return
from_bits_t
();
}
uint8_t
data
;
hip_fp8
()
=
default
;
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
(
const
hip_fp8
&
)
=
default
;
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
(
uint8_t
v
)
=
delete
;
explicit
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
(
uint8_t
v
,
from_bits_t
)
:
data
(
v
)
{
}
:
data
(
v
)
{}
#ifdef __HIP__MI300__
// NOTE: ON-DEVICE... always optimal bias
explicit
HIP_FP8_DEVICE
hip_fp8
(
float
v
)
:
data
(
hip_fp8_impl
::
to_fp8_from_fp32
(
v
))
{
}
:
data
(
hip_fp8_impl
::
to_fp8_from_fp32
(
v
))
{}
explicit
HIP_FP8_DEVICE
hip_fp8
(
_Float16
v
)
:
hip_fp8
(
static_cast
<
float
>
(
v
))
{
}
:
hip_fp8
(
static_cast
<
float
>
(
v
))
{}
// Host only implementation using s/w simulation
explicit
HIP_FP8_HOST
...
...
@@ -45,25 +38,24 @@ struct alignas(1) hip_fp8
// both Host and DEVICE for non-MI300 using s/w simulation
explicit
HIP_FP8_HOST_DEVICE
#endif // __HIP__MI300__
hip_fp8
(
float
v
)
{
data
=
hip_fp8_impl
::
to_float8
<
4
,
3
,
float
,
true
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
);
hip_fp8
(
float
v
)
{
data
=
hip_fp8_impl
::
to_float8
<
4
,
3
,
float
,
true
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
);
}
explicit
HIP_FP8_HOST_DEVICE
hip_fp8
(
double
v
)
:
hip_fp8
(
static_cast
<
float
>
(
v
))
{
}
:
hip_fp8
(
static_cast
<
float
>
(
v
))
{}
#ifdef __HIP__MI300__
// upcast using device specific intrinsic
explicit
inline
HIP_FP8_DEVICE
operator
float
()
const
{
explicit
inline
HIP_FP8_DEVICE
operator
float
()
const
{
float
fval
;
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
data
);
// upcast
asm
volatile
(
"v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
:
"=v"
(
fval
)
:
"v"
(
i32val
));
asm
volatile
(
"v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
:
"=v"
(
fval
)
:
"v"
(
i32val
));
return
fval
;
}
...
...
@@ -73,95 +65,73 @@ struct alignas(1) hip_fp8
explicit
inline
HIP_FP8_HOST_DEVICE
operator
float
()
const
#endif // __HIP__MI300__
{
return
hip_fp8_impl
::
from_float8
<
4
,
3
,
float
,
true
/*negative_zero_nan*/
>
(
data
);
return
hip_fp8_impl
::
from_float8
<
4
,
3
,
float
,
true
/*negative_zero_nan*/
>
(
data
);
}
};
namespace
std
{
inline
hip_fp8
sin
(
hip_fp8
a
)
{
return
hip_fp8
(
sinf
(
float
(
a
)));
}
inline
hip_fp8
cos
(
hip_fp8
a
)
{
return
hip_fp8
(
cosf
(
float
(
a
)));
}
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
real
(
const
hip_fp8
&
a
)
{
return
a
;
}
namespace
std
{
inline
hip_fp8
sin
(
hip_fp8
a
)
{
return
hip_fp8
(
sinf
(
float
(
a
)));
}
inline
hip_fp8
cos
(
hip_fp8
a
)
{
return
hip_fp8
(
cosf
(
float
(
a
)));
}
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
real
(
const
hip_fp8
&
a
)
{
return
a
;
}
}
// namespace std
// Special operator overloading
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
hip_fp8
&
f8
)
{
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
hip_fp8
&
f8
)
{
return
os
<<
float
(
f8
);
}
// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns
float
inline
HIP_FP8_HOST_DEVICE
float
operator
+
(
const
float
fa
,
hip_fp8
b
)
{
// mixed types, always converts to f32, does computation in f32, and returns
// float
inline
HIP_FP8_HOST_DEVICE
float
operator
+
(
const
float
fa
,
hip_fp8
b
)
{
return
(
fa
+
float
(
b
));
}
inline
HIP_FP8_HOST_DEVICE
float
operator
+
(
hip_fp8
a
,
const
float
fb
)
{
inline
HIP_FP8_HOST_DEVICE
float
operator
+
(
hip_fp8
a
,
const
float
fb
)
{
return
(
float
(
a
)
+
fb
);
}
inline
HIP_FP8_HOST_DEVICE
hip_fp8
operator
+
(
hip_fp8
a
,
hip_fp8
b
)
{
inline
HIP_FP8_HOST_DEVICE
hip_fp8
operator
+
(
hip_fp8
a
,
hip_fp8
b
)
{
return
hip_fp8
(
float
(
a
)
+
float
(
b
));
}
inline
HIP_FP8_HOST_DEVICE
hip_fp8
&
operator
+=
(
hip_fp8
&
a
,
hip_fp8
b
)
{
inline
HIP_FP8_HOST_DEVICE
hip_fp8
&
operator
+=
(
hip_fp8
&
a
,
hip_fp8
b
)
{
return
a
=
hip_fp8
(
float
(
a
)
+
float
(
b
));
}
// overloading multiplication, always returns float,
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
hip_fp8
a
,
hip_fp8
b
)
{
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
hip_fp8
a
,
hip_fp8
b
)
{
return
float
(
a
)
*
float
(
b
);
}
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
float
a
,
hip_fp8
b
)
{
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
float
a
,
hip_fp8
b
)
{
return
(
a
*
float
(
b
));
}
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
hip_fp8
a
,
float
b
)
{
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
hip_fp8
a
,
float
b
)
{
return
(
float
(
a
)
*
b
);
}
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
int32_t
a
,
hip_fp8
b
)
{
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
int32_t
a
,
hip_fp8
b
)
{
return
((
float
)
a
*
float
(
b
));
}
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
double
a
,
hip_fp8
b
)
{
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
double
a
,
hip_fp8
b
)
{
return
((
float
)
a
*
float
(
b
));
}
// overloading for compare
inline
HIP_FP8_HOST_DEVICE
bool
operator
==
(
hip_fp8
a
,
hip_fp8
b
)
{
inline
HIP_FP8_HOST_DEVICE
bool
operator
==
(
hip_fp8
a
,
hip_fp8
b
)
{
return
(
a
.
data
==
b
.
data
);
}
inline
HIP_FP8_HOST_DEVICE
bool
operator
!=
(
hip_fp8
a
,
hip_fp8
b
)
{
inline
HIP_FP8_HOST_DEVICE
bool
operator
!=
(
hip_fp8
a
,
hip_fp8
b
)
{
return
(
a
.
data
!=
b
.
data
);
}
inline
HIP_FP8_HOST_DEVICE
bool
operator
>=
(
hip_fp8
a
,
hip_fp8
b
)
{
inline
HIP_FP8_HOST_DEVICE
bool
operator
>=
(
hip_fp8
a
,
hip_fp8
b
)
{
return
static_cast
<
float
>
(
a
)
>=
static_cast
<
float
>
(
b
);
}
inline
HIP_FP8_HOST_DEVICE
bool
operator
>
(
hip_fp8
a
,
hip_fp8
b
)
{
inline
HIP_FP8_HOST_DEVICE
bool
operator
>
(
hip_fp8
a
,
hip_fp8
b
)
{
return
static_cast
<
float
>
(
a
)
>
static_cast
<
float
>
(
b
);
}
csrc/quantization/fp8/amd/hip_float8_impl.h
View file @
5f6d10c1
#pragma once
#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300__
#if defined(__HIPCC__) && \
(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300__
#endif
#ifdef __HIPCC__
#define HIP_FP8_HOST_DEVICE __host__ __device__
#define HIP_FP8_HOST __host__
#define HIP_FP8_DEVICE __device__
#define HIP_FP8_HOST_DEVICE __host__ __device__
#define HIP_FP8_HOST __host__
#define HIP_FP8_DEVICE __device__
#else
#define HIP_FP8_HOST_DEVICE
#define HIP_FP8_HOST
#define HIP_FP8_DEVICE
#define HIP_FP8_HOST_DEVICE
#define HIP_FP8_HOST
#define HIP_FP8_DEVICE
#endif
namespace
hip_fp8_impl
{
namespace
hip_fp8_impl
{
#ifdef __HIP__MI300__
HIP_FP8_DEVICE
uint8_t
to_fp8_from_fp32
(
float
v
)
{
HIP_FP8_DEVICE
uint8_t
to_fp8_from_fp32
(
float
v
)
{
uint8_t
i8data
;
union
{
float
fval
;
...
...
@@ -30,7 +29,8 @@ HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
uint32_t
ival
=
0
;
val
.
fval
=
v
;
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
240.0
,
-
240.0
);
}
...
...
@@ -43,20 +43,14 @@ HIP_FP8_DEVICE uint8_t to_fp8_from_fp32(float v)
}
#endif // __HIP__MI300__
HIP_FP8_HOST
inline
int
clz
(
uint32_t
x
)
{
return
__builtin_clz
(
x
);
}
HIP_FP8_HOST
inline
int
clz
(
uint32_t
x
)
{
return
__builtin_clz
(
x
);
}
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
HIP_FP8_DEVICE
inline
int
clz
(
uint32_t
x
)
{
return
__clz
(
x
);
}
HIP_FP8_DEVICE
inline
int
clz
(
uint32_t
x
)
{
return
__clz
(
x
);
}
#endif
template
<
int
we
,
int
wm
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
HIP_FP8_HOST_DEVICE
uint8_t
to_float8
(
T
_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
)
{
HIP_FP8_HOST_DEVICE
uint8_t
to_float8
(
T
_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
)
{
#ifdef __HIPCC__
constexpr
bool
is_half
=
std
::
is_same
<
T
,
_Float16
>::
value
;
#else
...
...
@@ -130,7 +124,8 @@ HIP_FP8_HOST_DEVICE uint8_t to_float8(T _x, bool stoch = false, uint32_t rng = 0
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const
int
f8_bias
=
(
1
<<
(
we
-
1
))
-
1
+
(
negative_zero_nan
?
1
:
0
);
const
int
f8_denormal_act_exponent
=
1
-
f8_bias
;
// actual exponent of f8 denormal
const
int
f8_denormal_act_exponent
=
1
-
f8_bias
;
// actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
...
...
@@ -146,20 +141,22 @@ are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
}
else
{
// fp32/fp16 is normal with implicit 1
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
f8_denormal_act_exponent
)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
range. For example fp8 nanoo mode, denormal exponent is -7, but if the
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
range. For example fp8 nanoo mode, denormal exponent is -7, but if the
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
Therefore it needs to be adjust to -6 and mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
}
else
{
// both fp32/fp16 and f8 are in normal range
exponent_diff
=
0
;
// exponent_diff=0 does not mean there is no
difference
// for this case,
// act_exponent could be
larger. Just that it does not need shift mantissa
exponent_diff
=
0
;
// exponent_diff=0 does not mean there is no
// difference for this case, act_exponent could be
//
larger. Just that it does not need shift mantissa
}
mantissa
+=
(
1
<<
mfmt
);
// Add the implicit 1 into mantissa
}
...
...
@@ -181,13 +178,16 @@ where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
bool
implicit_one
=
mantissa
&
(
1
<<
mfmt
);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
f8_bias
-
(
implicit_one
?
0
:
1
);
f8_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
f8_bias
-
(
implicit_one
?
0
:
1
);
// Now we have the exponent and mantissa adjusted
uint32_t
drop_mask
=
(
1
<<
(
mfmt
-
wm
))
-
1
;
bool
odd
=
mantissa
&
(
1
<<
(
mfmt
-
wm
));
// if the least significant bit that
// is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
bool
odd
=
mantissa
&
(
1
<<
(
mfmt
-
wm
));
// if the least significant bit
// that is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
// Now we deal with overflow
if
(
f8_exponent
==
0
)
{
...
...
@@ -222,8 +222,7 @@ where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
}
template
<
int
we
,
int
wm
,
typename
T
=
float
,
bool
negative_zero_nan
=
true
>
inline
HIP_FP8_HOST_DEVICE
T
from_float8
(
uint8_t
x
)
{
inline
HIP_FP8_HOST_DEVICE
T
from_float8
(
uint8_t
x
)
{
#ifdef __HIPCC__
constexpr
bool
is_half
=
std
::
is_same
<
T
,
_Float16
>::
value
;
#else
...
...
@@ -285,7 +284,8 @@ inline HIP_FP8_HOST_DEVICE T from_float8(uint8_t x)
return
reinterpret_cast
<
const
T
&>
(
retval
);
}
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
// subnormal input
if
(
exponent
==
0
)
{
...
...
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment