Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
vllm_cscc
Commits
5f6d10c1
Unverified
Commit
5f6d10c1
authored
May 22, 2024
by
Michael Goin
Committed by
GitHub
May 22, 2024
Browse files
[CI/Build] Enforce style for C++ and CUDA code with `clang-format` (#4722)
parent
9b9a10d6
Changes
64
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1701 additions
and
1872 deletions
+1701
-1872
csrc/cuda_utils_kernels.cu
csrc/cuda_utils_kernels.cu
+17
-23
csrc/custom_all_reduce.cu
csrc/custom_all_reduce.cu
+27
-28
csrc/custom_all_reduce.cuh
csrc/custom_all_reduce.cuh
+51
-54
csrc/custom_all_reduce_test.cu
csrc/custom_all_reduce_test.cu
+19
-19
csrc/dispatch_utils.h
csrc/dispatch_utils.h
+20
-22
csrc/layernorm_kernels.cu
csrc/layernorm_kernels.cu
+121
-121
csrc/moe/moe_ops.cpp
csrc/moe/moe_ops.cpp
+2
-1
csrc/moe/moe_ops.h
csrc/moe/moe_ops.h
+3
-5
csrc/moe_align_block_size_kernels.cu
csrc/moe_align_block_size_kernels.cu
+110
-101
csrc/ops.h
csrc/ops.h
+121
-209
csrc/pos_encoding_kernels.cu
csrc/pos_encoding_kernels.cu
+103
-126
csrc/pybind.cpp
csrc/pybind.cpp
+56
-86
csrc/quantization/aqlm/gemm_kernels.cu
csrc/quantization/aqlm/gemm_kernels.cu
+211
-325
csrc/quantization/awq/dequantize.cuh
csrc/quantization/awq/dequantize.cuh
+76
-62
csrc/quantization/awq/gemm_kernels.cu
csrc/quantization/awq/gemm_kernels.cu
+357
-254
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
+19
-19
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
+11
-11
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
+24
-23
csrc/quantization/fp8/amd/hip_float8.h
csrc/quantization/fp8/amd/hip_float8.h
+93
-123
csrc/quantization/fp8/amd/hip_float8_impl.h
csrc/quantization/fp8/amd/hip_float8_impl.h
+260
-260
No files found.
csrc/cuda_utils_kernels.cu
View file @
5f6d10c1
...
...
@@ -2,34 +2,28 @@
#include <hip/hip_runtime.h>
#include <hip/hip_runtime_api.h>
#endif
int
get_device_attribute
(
int
attribute
,
int
device_id
)
{
int
device
,
value
;
if
(
device_id
<
0
)
{
cudaGetDevice
(
&
device
);
}
else
{
device
=
device_id
;
}
cudaDeviceGetAttribute
(
&
value
,
static_cast
<
cudaDeviceAttr
>
(
attribute
),
device
);
return
value
;
int
get_device_attribute
(
int
attribute
,
int
device_id
)
{
int
device
,
value
;
if
(
device_id
<
0
)
{
cudaGetDevice
(
&
device
);
}
else
{
device
=
device_id
;
}
cudaDeviceGetAttribute
(
&
value
,
static_cast
<
cudaDeviceAttr
>
(
attribute
),
device
);
return
value
;
}
int
get_max_shared_memory_per_block_device_attribute
(
int
device_id
)
{
int
attribute
;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
int
get_max_shared_memory_per_block_device_attribute
(
int
device_id
)
{
int
attribute
;
// https://docs.nvidia.com/cuda/cuda-runtime-api/group__CUDART__TYPES.html
// cudaDevAttrMaxSharedMemoryPerBlockOptin = 97 if not is_hip() else 74
#ifdef USE_ROCM
attribute
=
hipDeviceAttributeMaxSharedMemoryPerBlock
;
attribute
=
hipDeviceAttributeMaxSharedMemoryPerBlock
;
#else
attribute
=
cudaDevAttrMaxSharedMemoryPerBlockOptin
;
attribute
=
cudaDevAttrMaxSharedMemoryPerBlockOptin
;
#endif
return
get_device_attribute
(
attribute
,
device_id
);
return
get_device_attribute
(
attribute
,
device_id
);
}
csrc/custom_all_reduce.cu
View file @
5f6d10c1
...
...
@@ -7,11 +7,11 @@
// fake pointer type
using
fptr_t
=
uint64_t
;
static_assert
(
sizeof
(
void
*
)
==
sizeof
(
fptr_t
));
static_assert
(
sizeof
(
void
*
)
==
sizeof
(
fptr_t
));
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
,
int
rank
,
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int
rank
,
bool
full_nvlink
)
{
int
world_size
=
offsets
.
size
();
if
(
world_size
>
8
)
...
...
@@ -29,7 +29,7 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
std
::
memcpy
(
&
ipc_handles
[
i
],
handles
[
i
].
data
(),
sizeof
(
cudaIpcMemHandle_t
));
}
return
(
fptr_t
)
new
vllm
::
CustomAllreduce
(
reinterpret_cast
<
vllm
::
Signal
*>
(
meta
.
data_ptr
()),
rank_data
.
data_ptr
(),
reinterpret_cast
<
vllm
::
Signal
*>
(
meta
.
data_ptr
()),
rank_data
.
data_ptr
(),
rank_data
.
numel
(),
ipc_handles
,
offsets
,
rank
,
full_nvlink
);
}
...
...
@@ -49,13 +49,13 @@ fptr_t init_custom_ar(torch::Tensor &meta, torch::Tensor &rank_data,
* 5. A[None].expand(2, -1, -1, -1): Not OK
* 6. A[:, 1:, 1:]: Not OK
*/
bool
_is_weak_contiguous
(
torch
::
Tensor
&
t
)
{
bool
_is_weak_contiguous
(
torch
::
Tensor
&
t
)
{
return
t
.
is_contiguous
()
||
(
t
.
storage
().
nbytes
()
-
t
.
storage_offset
()
*
t
.
element_size
()
==
t
.
numel
()
*
t
.
element_size
());
}
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int
max_size
,
int
world_size
,
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int
max_size
,
int
world_size
,
bool
full_nvlink
)
{
auto
inp_size
=
inp
.
numel
()
*
inp
.
element_size
();
// custom allreduce requires input byte size to be multiples of 16
...
...
@@ -67,28 +67,27 @@ bool should_custom_ar(torch::Tensor &inp, int max_size, int world_size,
return
false
;
}
void
_all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
void
_all_reduce
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
,
cudaStream_t
stream
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
TORCH_CHECK
(
_is_weak_contiguous
(
out
));
switch
(
out
.
scalar_type
())
{
case
at
::
ScalarType
::
Float
:
{
fa
->
allreduce
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
float
*>
(
out
.
data_ptr
()),
fa
->
allreduce
<
float
>
(
stream
,
reinterpret_cast
<
float
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
float
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
case
at
::
ScalarType
::
Half
:
{
fa
->
allreduce
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
());
fa
->
allreduce
<
half
>
(
stream
,
reinterpret_cast
<
half
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
half
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
case
at
::
ScalarType
::
BFloat16
:
{
fa
->
allreduce
<
nv_bfloat16
>
(
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
nv_bfloat16
*>
(
out
.
data_ptr
()),
out
.
numel
());
stream
,
reinterpret_cast
<
nv_bfloat16
*>
(
inp
.
data_ptr
()),
reinterpret_cast
<
nv_bfloat16
*>
(
out
.
data_ptr
()),
out
.
numel
());
break
;
}
#endif
...
...
@@ -98,7 +97,7 @@ void _all_reduce(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out,
}
}
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
)
{
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
TORCH_CHECK_EQ
(
inp
.
scalar_type
(),
out
.
scalar_type
());
...
...
@@ -106,8 +105,8 @@ void all_reduce_reg(fptr_t _fa, torch::Tensor &inp, torch::Tensor &out) {
_all_reduce
(
_fa
,
inp
,
out
,
stream
);
}
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
)
{
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
inp
));
auto
stream
=
c10
::
cuda
::
getCurrentCUDAStream
().
stream
();
...
...
@@ -122,27 +121,27 @@ void all_reduce_unreg(fptr_t _fa, torch::Tensor &inp, torch::Tensor ®_buffer,
}
void
dispose
(
fptr_t
_fa
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
delete
fa
;
}
int
meta_size
()
{
return
sizeof
(
vllm
::
Signal
);
}
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
fa
->
register_buffer
(
handles
,
offsets
,
t
.
data_ptr
());
}
std
::
pair
<
std
::
vector
<
uint8_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
return
fa
->
get_graph_buffer_ipc_meta
();
}
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
offsets
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
auto
fa
=
reinterpret_cast
<
vllm
::
CustomAllreduce
*>
(
_fa
);
fa
->
register_graph_buffers
(
handles
,
offsets
);
}
csrc/custom_all_reduce.cuh
View file @
5f6d10c1
...
...
@@ -31,9 +31,9 @@ struct Signal {
alignas
(
128
)
uint32_t
end
[
kMaxBlocks
][
8
];
};
struct
__align__
(
16
)
RankData
{
const
void
*
__restrict__
ptrs
[
8
];
};
struct
__align__
(
16
)
RankData
{
const
void
*
__restrict__
ptrs
[
8
];
};
struct
__align__
(
16
)
RankSignals
{
volatile
Signal
*
signals
[
8
];
};
struct
__align__
(
16
)
RankSignals
{
volatile
Signal
*
signals
[
8
];
};
// like std::array, but aligned
template
<
typename
T
,
int
sz
>
...
...
@@ -68,11 +68,11 @@ DINLINE half downcast_s(float val) {
// scalar add functions
// for some reason when compiling with Pytorch, the + operator for half and
// bfloat is disabled so we call the intrinsics directly
DINLINE
half
&
assign_add
(
half
&
a
,
half
b
)
{
DINLINE
half
&
assign_add
(
half
&
a
,
half
b
)
{
a
=
__hadd
(
a
,
b
);
return
a
;
}
DINLINE
float
&
assign_add
(
float
&
a
,
float
b
)
{
return
a
+=
b
;
}
DINLINE
float
&
assign_add
(
float
&
a
,
float
b
)
{
return
a
+=
b
;
}
#if (__CUDA_ARCH__ >= 800 || !defined(__CUDA_ARCH__))
DINLINE
float
upcast_s
(
nv_bfloat16
val
)
{
return
__bfloat162float
(
val
);
}
...
...
@@ -80,14 +80,14 @@ template <>
DINLINE
nv_bfloat16
downcast_s
(
float
val
)
{
return
__float2bfloat16
(
val
);
}
DINLINE
nv_bfloat16
&
assign_add
(
nv_bfloat16
&
a
,
nv_bfloat16
b
)
{
DINLINE
nv_bfloat16
&
assign_add
(
nv_bfloat16
&
a
,
nv_bfloat16
b
)
{
a
=
__hadd
(
a
,
b
);
return
a
;
}
#endif
template
<
typename
T
,
int
N
>
DINLINE
array_t
<
T
,
N
>
&
packed_assign_add
(
array_t
<
T
,
N
>
&
a
,
array_t
<
T
,
N
>
b
)
{
DINLINE
array_t
<
T
,
N
>&
packed_assign_add
(
array_t
<
T
,
N
>&
a
,
array_t
<
T
,
N
>
b
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
N
;
i
++
)
{
assign_add
(
a
.
data
[
i
],
b
.
data
[
i
]);
...
...
@@ -128,7 +128,7 @@ DINLINE O downcast(array_t<float, O::size> val) {
// prior memory accesses. Note: volatile writes will not be reordered against
// other volatile writes.
template
<
int
ngpus
>
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
DINLINE
void
start_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
int
rank
)
{
if
(
threadIdx
.
x
<
ngpus
)
{
// reset flag for next time
...
...
@@ -137,8 +137,7 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write
sg
.
signals
[
threadIdx
.
x
]
->
start
[
blockIdx
.
x
][
rank
]
=
1
;
// wait until we got true from all ranks
while
(
!
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
])
;
while
(
!
self_sg
->
start
[
blockIdx
.
x
][
threadIdx
.
x
]);
}
__syncthreads
();
}
...
...
@@ -147,13 +146,13 @@ DINLINE void start_sync(const RankSignals &sg, volatile Signal *self_sg,
// barrier in the all reduce kernel. If it's the final synchronization barrier,
// we don't need to make any visibility guarantees for prior memory accesses.
template
<
int
ngpus
,
bool
final_sync
=
false
>
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
DINLINE
void
end_sync
(
const
RankSignals
&
sg
,
volatile
Signal
*
self_sg
,
int
rank
)
{
__syncthreads
();
// eliminate the case that prior writes are not visible after signals become
// visible. Note that I did not managed to make this happen through a lot of
// testing. Might be the case that hardware provides stronger guarantee than
// the memory model.
// the memory model.
if
constexpr
(
!
final_sync
)
__threadfence_system
();
if
(
threadIdx
.
x
<
ngpus
)
{
// reset flag for next time
...
...
@@ -162,14 +161,13 @@ DINLINE void end_sync(const RankSignals &sg, volatile Signal *self_sg,
// Latency = 1 p2p write
sg
.
signals
[
threadIdx
.
x
]
->
end
[
blockIdx
.
x
][
rank
]
=
1
;
// wait until we got true from all ranks
while
(
!
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
])
;
while
(
!
self_sg
->
end
[
blockIdx
.
x
][
threadIdx
.
x
]);
}
if
constexpr
(
!
final_sync
)
__syncthreads
();
}
template
<
typename
P
,
int
ngpus
,
typename
A
>
DINLINE
P
packed_reduce
(
const
P
*
ptrs
[],
int
idx
)
{
DINLINE
P
packed_reduce
(
const
P
*
ptrs
[],
int
idx
)
{
A
tmp
=
upcast
(
ptrs
[
0
][
idx
]);
#pragma unroll
for
(
int
i
=
1
;
i
<
ngpus
;
i
++
)
{
...
...
@@ -180,8 +178,8 @@ DINLINE P packed_reduce(const P *ptrs[], int idx) {
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_1stage
(
RankData
*
_dp
,
RankSignals
sg
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
cross_device_reduce_1stage
(
RankData
*
_dp
,
RankSignals
sg
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
using
P
=
typename
packed_t
<
T
>::
P
;
using
A
=
typename
packed_t
<
T
>::
A
;
...
...
@@ -192,21 +190,20 @@ __global__ void __launch_bounds__(512, 1)
// do the actual reduction
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
((
P
*
)
result
)[
idx
]
=
packed_reduce
<
P
,
ngpus
,
A
>
((
const
P
**
)
&
dp
.
ptrs
[
0
],
idx
);
}
end_sync
<
ngpus
,
true
>
(
sg
,
self_sg
,
rank
);
}
template
<
typename
P
>
DINLINE
P
*
get_tmp_buf
(
volatile
Signal
*
sg
)
{
return
(
P
*
)(((
Signal
*
)
sg
)
+
1
);
DINLINE
P
*
get_tmp_buf
(
volatile
Signal
*
sg
)
{
return
(
P
*
)(((
Signal
*
)
sg
)
+
1
);
}
template
<
typename
T
,
int
ngpus
>
__global__
void
__launch_bounds__
(
512
,
1
)
cross_device_reduce_2stage
(
RankData
*
_dp
,
RankSignals
sg
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
cross_device_reduce_2stage
(
RankData
*
_dp
,
RankSignals
sg
,
volatile
Signal
*
self_sg
,
T
*
__restrict__
result
,
int
rank
,
int
size
)
{
int
tid
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
int
stride
=
gridDim
.
x
*
blockDim
.
x
;
...
...
@@ -216,12 +213,12 @@ __global__ void __launch_bounds__(512, 1)
int
start
=
rank
*
part
;
int
end
=
rank
==
ngpus
-
1
?
size
:
start
+
part
;
int
largest_part
=
part
+
size
%
ngpus
;
const
P
*
ptrs
[
ngpus
];
P
*
tmps
[
ngpus
];
const
P
*
ptrs
[
ngpus
];
P
*
tmps
[
ngpus
];
#pragma unroll
for
(
int
i
=
0
;
i
<
ngpus
;
i
++
)
{
int
target
=
(
rank
+
i
)
%
ngpus
;
ptrs
[
i
]
=
(
const
P
*
)
_dp
->
ptrs
[
target
];
ptrs
[
i
]
=
(
const
P
*
)
_dp
->
ptrs
[
target
];
tmps
[
i
]
=
get_tmp_buf
<
P
>
(
sg
.
signals
[
target
]);
}
auto
tmp_out
=
tmps
[
0
];
...
...
@@ -243,7 +240,7 @@ __global__ void __launch_bounds__(512, 1)
int
gather_from_rank
=
((
rank
+
i
)
%
ngpus
);
if
(
gather_from_rank
==
ngpus
-
1
||
idx
<
part
)
{
int
dst_idx
=
gather_from_rank
*
part
+
idx
;
((
P
*
)
result
)[
dst_idx
]
=
tmps
[
i
][
idx
];
((
P
*
)
result
)[
dst_idx
]
=
tmps
[
i
][
idx
];
}
}
}
...
...
@@ -261,14 +258,14 @@ class CustomAllreduce {
// below are device pointers
RankSignals
sg_
;
std
::
unordered_map
<
void
*
,
RankData
*>
buffers_
;
Signal
*
self_sg_
;
std
::
unordered_map
<
void
*
,
RankData
*>
buffers_
;
Signal
*
self_sg_
;
// stores the registered device pointers from all ranks
RankData
*
d_rank_data_base_
,
*
d_rank_data_end_
;
std
::
vector
<
void
*>
graph_unreg_buffers_
;
std
::
vector
<
void
*>
graph_unreg_buffers_
;
// a map from IPC handles to opened IPC pointers
std
::
map
<
IPC_KEY
,
char
*>
ipc_handles_
;
std
::
map
<
IPC_KEY
,
char
*>
ipc_handles_
;
/**
* meta is a pointer to device metadata and temporary buffer for allreduce.
...
...
@@ -279,22 +276,22 @@ class CustomAllreduce {
* note: this class does not own any device memory. Any required buffers
* are passed in from the constructor
*/
CustomAllreduce
(
Signal
*
meta
,
void
*
rank_data
,
size_t
rank_data_sz
,
const
cudaIpcMemHandle_t
*
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
,
int
rank
,
CustomAllreduce
(
Signal
*
meta
,
void
*
rank_data
,
size_t
rank_data_sz
,
const
cudaIpcMemHandle_t
*
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int
rank
,
bool
full_nvlink
=
true
)
:
rank_
(
rank
),
world_size_
(
offsets
.
size
()),
full_nvlink_
(
full_nvlink
),
self_sg_
(
meta
),
d_rank_data_base_
(
reinterpret_cast
<
RankData
*>
(
rank_data
)),
d_rank_data_base_
(
reinterpret_cast
<
RankData
*>
(
rank_data
)),
d_rank_data_end_
(
d_rank_data_base_
+
rank_data_sz
/
sizeof
(
RankData
))
{
for
(
int
i
=
0
;
i
<
world_size_
;
i
++
)
{
Signal
*
rank_sg
;
Signal
*
rank_sg
;
if
(
i
!=
rank_
)
{
char
*
handle
=
open_ipc_handle
(
&
handles
[
i
]);
char
*
handle
=
open_ipc_handle
(
&
handles
[
i
]);
handle
+=
offsets
[
i
];
rank_sg
=
(
Signal
*
)
handle
;
rank_sg
=
(
Signal
*
)
handle
;
}
else
{
rank_sg
=
self_sg_
;
}
...
...
@@ -302,13 +299,13 @@ class CustomAllreduce {
}
}
char
*
open_ipc_handle
(
const
void
*
ipc_handle
)
{
char
*
open_ipc_handle
(
const
void
*
ipc_handle
)
{
auto
[
it
,
new_handle
]
=
ipc_handles_
.
insert
({
*
((
IPC_KEY
*
)
ipc_handle
),
nullptr
});
ipc_handles_
.
insert
({
*
((
IPC_KEY
*
)
ipc_handle
),
nullptr
});
if
(
new_handle
)
{
char
*
ipc_ptr
;
CUDACHECK
(
cudaIpcOpenMemHandle
((
void
**
)
&
ipc_ptr
,
*
((
const
cudaIpcMemHandle_t
*
)
ipc_handle
),
char
*
ipc_ptr
;
CUDACHECK
(
cudaIpcOpenMemHandle
((
void
**
)
&
ipc_ptr
,
*
((
const
cudaIpcMemHandle_t
*
)
ipc_handle
),
cudaIpcMemLazyEnablePeerAccess
));
it
->
second
=
ipc_ptr
;
}
...
...
@@ -323,7 +320,7 @@ class CustomAllreduce {
std
::
vector
<
int64_t
>
offsets
(
num_buffers
);
for
(
int
i
=
0
;
i
<
num_buffers
;
i
++
)
{
auto
ptr
=
graph_unreg_buffers_
[
i
];
void
*
base_ptr
;
void
*
base_ptr
;
// note: must share the base address of each allocation, or we get wrong
// address
if
(
cuPointerGetAttribute
(
&
base_ptr
,
...
...
@@ -331,8 +328,8 @@ class CustomAllreduce {
(
CUdeviceptr
)
ptr
)
!=
CUDA_SUCCESS
)
throw
std
::
runtime_error
(
"failed to get pointer attr"
);
CUDACHECK
(
cudaIpcGetMemHandle
(
(
cudaIpcMemHandle_t
*
)
&
handles
[
i
*
handle_sz
],
base_ptr
));
offsets
[
i
]
=
((
char
*
)
ptr
)
-
((
char
*
)
base_ptr
);
(
cudaIpcMemHandle_t
*
)
&
handles
[
i
*
handle_sz
],
base_ptr
));
offsets
[
i
]
=
((
char
*
)
ptr
)
-
((
char
*
)
base_ptr
);
}
return
std
::
make_pair
(
handles
,
offsets
);
}
...
...
@@ -344,13 +341,13 @@ class CustomAllreduce {
std
::
to_string
(
d_rank_data_base_
+
num
-
d_rank_data_end_
));
}
void
register_buffer
(
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
,
void
*
self
)
{
void
register_buffer
(
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
void
*
self
)
{
check_rank_data_capacity
();
RankData
data
;
for
(
int
i
=
0
;
i
<
world_size_
;
i
++
)
{
if
(
i
!=
rank_
)
{
char
*
handle
=
open_ipc_handle
(
handles
[
i
].
data
());
char
*
handle
=
open_ipc_handle
(
handles
[
i
].
data
());
handle
+=
offsets
[
i
];
data
.
ptrs
[
i
]
=
handle
;
}
else
{
...
...
@@ -371,17 +368,17 @@ class CustomAllreduce {
// got a different address. IPC handles have internal reference counting
// mechanism so overhead should be small.
void
register_graph_buffers
(
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
offsets
)
{
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
)
{
auto
num_buffers
=
graph_unreg_buffers_
.
size
();
check_rank_data_capacity
(
num_buffers
);
std
::
vector
<
RankData
>
rank_data
(
num_buffers
);
for
(
int
i
=
0
;
i
<
num_buffers
;
i
++
)
{
auto
self_ptr
=
graph_unreg_buffers_
[
i
];
auto
&
rd
=
rank_data
[
i
];
auto
&
rd
=
rank_data
[
i
];
for
(
int
j
=
0
;
j
<
world_size_
;
j
++
)
{
if
(
j
!=
rank_
)
{
char
*
handle
=
char
*
handle
=
open_ipc_handle
(
&
handles
[
j
][
i
*
sizeof
(
cudaIpcMemHandle_t
)]);
handle
+=
offsets
[
j
][
i
];
rd
.
ptrs
[
j
]
=
handle
;
...
...
@@ -405,7 +402,7 @@ class CustomAllreduce {
* will cause contention on NVLink bus.
*/
template
<
typename
T
>
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
void
allreduce
(
cudaStream_t
stream
,
T
*
input
,
T
*
output
,
int
size
,
int
threads
=
512
,
int
block_limit
=
36
)
{
auto
d
=
packed_t
<
T
>::
P
::
size
;
if
(
size
%
d
!=
0
)
...
...
@@ -418,7 +415,7 @@ class CustomAllreduce {
std
::
to_string
(
kMaxBlocks
)
+
". Got "
+
std
::
to_string
(
block_limit
));
RankData
*
ptrs
;
RankData
*
ptrs
;
cudaStreamCaptureStatus
status
;
CUDACHECK
(
cudaStreamIsCapturing
(
stream
,
&
status
));
if
(
status
==
cudaStreamCaptureStatusActive
)
{
...
...
csrc/custom_all_reduce_test.cu
View file @
5f6d10c1
...
...
@@ -48,7 +48,7 @@ __global__ void dummy_kernel() {
}
template
<
typename
T
>
__global__
void
set_data
(
T
*
data
,
int
size
,
int
myRank
)
{
__global__
void
set_data
(
T
*
data
,
int
size
,
int
myRank
)
{
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
data
[
idx
]
=
myRank
*
0.11
f
;
...
...
@@ -56,8 +56,8 @@ __global__ void set_data(T *data, int size, int myRank) {
}
template
<
typename
T
>
__global__
void
convert_data
(
const
T
*
data1
,
const
T
*
data2
,
double
*
fdata1
,
double
*
fdata2
,
int
size
)
{
__global__
void
convert_data
(
const
T
*
data1
,
const
T
*
data2
,
double
*
fdata1
,
double
*
fdata2
,
int
size
)
{
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
fdata1
[
idx
]
=
data1
[
idx
];
...
...
@@ -65,7 +65,7 @@ __global__ void convert_data(const T *data1, const T *data2, double *fdata1,
}
}
__global__
void
init_rand
(
curandState_t
*
state
,
int
size
,
int
nRanks
)
{
__global__
void
init_rand
(
curandState_t
*
state
,
int
size
,
int
nRanks
)
{
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
...
...
@@ -75,7 +75,7 @@ __global__ void init_rand(curandState_t *state, int size, int nRanks) {
}
template
<
typename
T
>
__global__
void
gen_data
(
curandState_t
*
state
,
T
*
data
,
double
*
ground_truth
,
__global__
void
gen_data
(
curandState_t
*
state
,
T
*
data
,
double
*
ground_truth
,
int
myRank
,
int
nRanks
,
int
size
)
{
for
(
int
idx
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
idx
<
size
;
idx
+=
gridDim
.
x
*
blockDim
.
x
)
{
...
...
@@ -91,9 +91,9 @@ __global__ void gen_data(curandState_t *state, T *data, double *ground_truth,
}
template
<
typename
T
>
void
run
(
int
myRank
,
int
nRanks
,
ncclComm_t
&
comm
,
int
threads
,
int
block_limit
,
void
run
(
int
myRank
,
int
nRanks
,
ncclComm_t
&
comm
,
int
threads
,
int
block_limit
,
int
data_size
,
bool
performance_test
)
{
T
*
result
;
T
*
result
;
cudaStream_t
stream
;
CUDACHECK
(
cudaStreamCreateWithFlags
(
&
stream
,
cudaStreamNonBlocking
));
CUDACHECK
(
cudaMalloc
(
&
result
,
data_size
*
sizeof
(
T
)));
...
...
@@ -101,8 +101,8 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
cudaIpcMemHandle_t
self_data_handle
;
cudaIpcMemHandle_t
data_handles
[
8
];
vllm
::
Signal
*
buffer
;
T
*
self_data_copy
;
vllm
::
Signal
*
buffer
;
T
*
self_data_copy
;
/**
* Allocate IPC buffer
*
...
...
@@ -125,22 +125,22 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
MPI_BYTE
,
data_handles
,
sizeof
(
cudaIpcMemHandle_t
),
MPI_BYTE
,
MPI_COMM_WORLD
));
void
*
rank_data
;
void
*
rank_data
;
size_t
rank_data_sz
=
16
*
1024
*
1024
;
CUDACHECK
(
cudaMalloc
(
&
rank_data
,
rank_data_sz
));
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
0
);
vllm
::
CustomAllreduce
fa
(
buffer
,
rank_data
,
rank_data_sz
,
data_handles
,
offsets
,
myRank
);
auto
*
self_data
=
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
char
*>
(
buffer
)
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
auto
*
self_data
=
reinterpret_cast
<
T
*>
(
reinterpret_cast
<
char
*>
(
buffer
)
+
sizeof
(
vllm
::
Signal
)
+
data_size
*
sizeof
(
T
));
// hack buffer registration
{
std
::
vector
<
std
::
string
>
handles
;
handles
.
reserve
(
nRanks
);
for
(
int
i
=
0
;
i
<
nRanks
;
i
++
)
{
char
*
begin
=
(
char
*
)
&
data_handles
[
i
];
char
*
end
=
(
char
*
)
&
data_handles
[
i
+
1
];
char
*
begin
=
(
char
*
)
&
data_handles
[
i
];
char
*
end
=
(
char
*
)
&
data_handles
[
i
+
1
];
handles
.
emplace_back
(
begin
,
end
);
}
std
::
vector
<
int64_t
>
offsets
(
nRanks
,
...
...
@@ -148,9 +148,9 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
fa
.
register_buffer
(
handles
,
offsets
,
self_data
);
}
double
*
ground_truth
;
double
*
ground_truth
;
CUDACHECK
(
cudaMallocHost
(
&
ground_truth
,
data_size
*
sizeof
(
double
)));
curandState_t
*
states
;
curandState_t
*
states
;
CUDACHECK
(
cudaMalloc
(
&
states
,
sizeof
(
curandState_t
)
*
nRanks
*
data_size
));
init_rand
<<<
108
,
1024
,
0
,
stream
>>>
(
states
,
data_size
,
nRanks
);
gen_data
<
T
><<<
108
,
1024
,
0
,
stream
>>>
(
states
,
self_data
,
ground_truth
,
myRank
,
...
...
@@ -287,7 +287,7 @@ void run(int myRank, int nRanks, ncclComm_t &comm, int threads, int block_limit,
CUDACHECK
(
cudaStreamDestroy
(
stream
));
}
int
main
(
int
argc
,
char
**
argv
)
{
int
main
(
int
argc
,
char
**
argv
)
{
int
nRanks
,
myRank
;
MPICHECK
(
MPI_Init
(
&
argc
,
&
argv
));
MPICHECK
(
MPI_Comm_rank
(
MPI_COMM_WORLD
,
&
myRank
));
...
...
@@ -296,7 +296,7 @@ int main(int argc, char **argv) {
ncclUniqueId
id
;
ncclComm_t
comm
;
if
(
myRank
==
0
)
ncclGetUniqueId
(
&
id
);
MPICHECK
(
MPI_Bcast
(
static_cast
<
void
*>
(
&
id
),
sizeof
(
id
),
MPI_BYTE
,
0
,
MPICHECK
(
MPI_Bcast
(
static_cast
<
void
*>
(
&
id
),
sizeof
(
id
),
MPI_BYTE
,
0
,
MPI_COMM_WORLD
));
NCCLCHECK
(
ncclCommInitRank
(
&
comm
,
nRanks
,
id
,
myRank
));
...
...
csrc/dispatch_utils.h
View file @
5f6d10c1
...
...
@@ -6,32 +6,30 @@
#include <torch/extension.h>
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...)
\
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
\
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...)
\
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
\
#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...)
\
AT_DISPATCH_SWITCH(
\
TYPE, NAME,
VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...)
\
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__)
\
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__)
\
#define VLLM_DISPATCH_FLOATING_AND_BYTE_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(
TYPE, NAME,
\
VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_CASE_INTEGRAL_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Byte, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Short, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Int, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Long, __VA_ARGS__)
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
#define VLLM_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH(TYPE, NAME, VLLM_DISPATCH_CASE_INTEGRAL_TYPES(__VA_ARGS__))
csrc/layernorm_kernels.cu
View file @
5f6d10c1
...
...
@@ -11,26 +11,24 @@
#include <hip/hip_bf16.h>
#include <hip/hip_fp16.h>
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat162
=
__hip_bfloat162
;
using
__nv_bfloat16
=
__hip_bfloat16
;
using
__nv_bfloat162
=
__hip_bfloat162
;
#endif
namespace
vllm
{
// TODO(woosuk): Further optimize this kernel.
template
<
typename
scalar_t
>
template
<
typename
scalar_t
>
__global__
void
rms_norm_kernel
(
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
scalar_t
*
__restrict__
out
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
const
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
variance
+=
x
*
x
;
}
variance
=
blockReduceSum
<
float
>
(
variance
);
...
...
@@ -40,12 +38,12 @@ __global__ void rms_norm_kernel(
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)
(
x
*
s_variance
))
*
weight
[
idx
];
float
x
=
(
float
)
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
out
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
}
/* Converter structs for the conversion from torch types to HIP/CUDA types,
and the associated type conversions within HIP/CUDA. These helpers need
to be implemented for now because the relevant type conversion
...
...
@@ -54,51 +52,68 @@ __global__ void rms_norm_kernel(
Each struct should have the member static constexpr bool `exists`:
If false, the optimized kernel is not used for the corresponding torch type.
If true, the struct should be fully defined as shown in the examples below.
If true, the struct should be fully defined as shown in the examples below.
*/
template
<
typename
torch_type
>
struct
_typeConvert
{
static
constexpr
bool
exists
=
false
;
};
template
<
typename
torch_type
>
struct
_typeConvert
{
static
constexpr
bool
exists
=
false
;
};
#if defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
// CUDA < 12.0 runs into issues with packed type conversion
template
<
>
template
<
>
struct
_typeConvert
<
c10
::
Half
>
{
static
constexpr
bool
exists
=
true
;
using
hip_type
=
__half
;
using
packed_hip_type
=
__half2
;
__device__
static
inline
float
convert
(
hip_type
x
)
{
return
__half2float
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__half22float2
(
x
);
}
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2half_rn
(
x
);
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22half2_rn
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__half22float2
(
x
);
}
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2half_rn
(
x
);
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22half2_rn
(
x
);
}
};
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
// CUDA_ARCH < 800 does not have BF16 support
// TODO: Add in ROCm support once public headers handle bf16 maturely
template
<
>
template
<
>
struct
_typeConvert
<
c10
::
BFloat16
>
{
static
constexpr
bool
exists
=
true
;
using
hip_type
=
__nv_bfloat16
;
using
packed_hip_type
=
__nv_bfloat162
;
__device__
static
inline
float
convert
(
hip_type
x
)
{
return
__bfloat162float
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__bfloat1622float2
(
x
);
}
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2bfloat16
(
x
);
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22bfloat162_rn
(
x
);
}
__device__
static
inline
float
convert
(
hip_type
x
)
{
return
__bfloat162float
(
x
);
}
__device__
static
inline
float2
convert
(
packed_hip_type
x
)
{
return
__bfloat1622float2
(
x
);
}
__device__
static
inline
hip_type
convert
(
float
x
)
{
return
__float2bfloat16
(
x
);
}
__device__
static
inline
packed_hip_type
convert
(
float2
x
)
{
return
__float22bfloat162_rn
(
x
);
}
};
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >= 12000))
#endif // defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 800
#endif // defined(USE_ROCM) || (defined(CUDA_VERSION) && (CUDA_VERSION >=
// 12000))
/* Vector POD struct to generate vectorized and packed FP16/BF16 ops
for appropriate specializations of fused_add_rms_norm_kernel.
Only functions that are necessary in that kernel are implemented.
Alignment to 16 bytes is required to use 128-bit global memory ops.
*/
template
<
typename
scalar_t
,
int
width
>
template
<
typename
scalar_t
,
int
width
>
struct
alignas
(
16
)
_f16Vec
{
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
/* Not theoretically necessary that width is a power of 2 but should
almost always be the case for optimization purposes */
static_assert
(
width
>
0
&&
(
width
&
(
width
-
1
))
==
0
,
"Width is not a positive power of 2!"
);
using
Converter
=
_typeConvert
<
scalar_t
>
;
...
...
@@ -108,51 +123,49 @@ struct alignas(16) _f16Vec {
__device__
_f16Vec
&
operator
+=
(
const
_f16Vec
<
scalar_t
,
width
>&
other
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
temp
+=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
temp
+=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
data
[
i
]
+=
other
.
data
[
i
];
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
data
[
i
]
+=
other
.
data
[
i
];
}
return
*
this
;
}
__device__
_f16Vec
&
operator
*=
(
const
_f16Vec
<
scalar_t
,
width
>&
other
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
temp
*=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
T2
temp
{
data
[
i
],
data
[
i
+
1
]};
temp
*=
T2
{
other
.
data
[
i
],
other
.
data
[
i
+
1
]};
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
else
{
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
data
[
i
]
*=
other
.
data
[
i
];
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
data
[
i
]
*=
other
.
data
[
i
];
}
return
*
this
;
}
__device__
_f16Vec
&
operator
*=
(
const
float
scale
)
{
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
float2
temp_f
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
float2
temp_f
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
temp_f
.
x
*=
scale
;
temp_f
.
y
*=
scale
;
T2
temp
=
Converter
::
convert
(
temp_f
);
data
[
i
]
=
temp
.
x
;
data
[
i
+
1
]
=
temp
.
y
;
data
[
i
+
1
]
=
temp
.
y
;
}
}
else
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
float
temp
=
Converter
::
convert
(
data
[
i
])
*
scale
;
data
[
i
]
=
Converter
::
convert
(
temp
);
...
...
@@ -164,13 +177,13 @@ struct alignas(16) _f16Vec {
__device__
float
sum_squares
()
const
{
float
result
=
0.0
f
;
if
constexpr
(
width
%
2
==
0
)
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
i
+=
2
)
{
float2
z
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
float2
z
=
Converter
::
convert
(
T2
{
data
[
i
],
data
[
i
+
1
]});
result
+=
z
.
x
*
z
.
x
+
z
.
y
*
z
.
y
;
}
}
else
{
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
width
;
++
i
)
{
float
x
=
Converter
::
convert
(
data
[
i
]);
result
+=
x
*
x
;
...
...
@@ -184,15 +197,13 @@ struct alignas(16) _f16Vec {
Additional optimizations we can make in this case are
packed and vectorized operations, which help with the
memory latency bottleneck. */
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
>
0
)
&&
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
// Sanity checks on our vector struct and type-punned pointer arithmetic
static_assert
(
std
::
is_pod_v
<
_f16Vec
<
scalar_t
,
width
>>
);
static_assert
(
sizeof
(
_f16Vec
<
scalar_t
,
width
>
)
==
sizeof
(
scalar_t
)
*
width
);
...
...
@@ -203,9 +214,12 @@ __global__ std::enable_if_t<
/* These and the argument pointers are all declared `restrict` as they are
not aliased in practice. Argument pointers should not be dereferenced
in this kernel as that would be undefined behavior */
auto
*
__restrict__
input_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
input
);
auto
*
__restrict__
residual_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
residual
);
auto
*
__restrict__
weight_v
=
reinterpret_cast
<
const
_f16Vec
<
scalar_t
,
width
>*>
(
weight
);
auto
*
__restrict__
input_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
input
);
auto
*
__restrict__
residual_v
=
reinterpret_cast
<
_f16Vec
<
scalar_t
,
width
>*>
(
residual
);
auto
*
__restrict__
weight_v
=
reinterpret_cast
<
const
_f16Vec
<
scalar_t
,
width
>*>
(
weight
);
for
(
int
idx
=
threadIdx
.
x
;
idx
<
vec_hidden_size
;
idx
+=
blockDim
.
x
)
{
int
id
=
blockIdx
.
x
*
vec_hidden_size
+
idx
;
...
...
@@ -215,10 +229,11 @@ __global__ std::enable_if_t<
residual_v
[
id
]
=
temp
;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
calculation of max_block_size in fused_add_rms_norm */
if
(
num_tokens
<
256
)
{
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
...
...
@@ -233,52 +248,50 @@ __global__ std::enable_if_t<
}
}
/* Generic fused_add_rms_norm_kernel
The width field is not used here but necessary for other specializations.
*/
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
template
<
typename
scalar_t
,
int
width
>
__global__
std
::
enable_if_t
<
(
width
==
0
)
||
!
_typeConvert
<
scalar_t
>::
exists
>
fused_add_rms_norm_kernel
(
scalar_t
*
__restrict__
input
,
// [..., hidden_size]
scalar_t
*
__restrict__
residual
,
// [..., hidden_size]
const
scalar_t
*
__restrict__
weight
,
// [hidden_size]
const
float
epsilon
,
const
int
num_tokens
,
const
int
hidden_size
)
{
__shared__
float
s_variance
;
float
variance
=
0.0
f
;
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
scalar_t
z
=
input
[
blockIdx
.
x
*
hidden_size
+
idx
];
z
+=
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
float
x
=
(
float
)
z
;
float
x
=
(
float
)
z
;
variance
+=
x
*
x
;
residual
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
z
;
}
/* Keep the following if-else block in sync with the
calculation of max_block_size in fused_add_rms_norm */
calculation of max_block_size in fused_add_rms_norm */
if
(
num_tokens
<
256
)
{
variance
=
blockReduceSum
<
float
,
1024
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
}
else
variance
=
blockReduceSum
<
float
,
256
>
(
variance
);
if
(
threadIdx
.
x
==
0
)
{
s_variance
=
rsqrtf
(
variance
/
hidden_size
+
epsilon
);
}
__syncthreads
();
for
(
int
idx
=
threadIdx
.
x
;
idx
<
hidden_size
;
idx
+=
blockDim
.
x
)
{
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
input
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)
(
x
*
s_variance
))
*
weight
[
idx
];
float
x
=
(
float
)
residual
[
blockIdx
.
x
*
hidden_size
+
idx
];
input
[
blockIdx
.
x
*
hidden_size
+
idx
]
=
((
scalar_t
)(
x
*
s_variance
))
*
weight
[
idx
];
}
}
}
// namespace vllm
}
// namespace vllm
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
float
epsilon
)
{
void
rms_norm
(
torch
::
Tensor
&
out
,
// [..., hidden_size]
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
float
epsilon
)
{
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
...
@@ -286,40 +299,27 @@ void rms_norm(
dim3
block
(
std
::
min
(
hidden_size
,
1024
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
input
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
VLLM_DISPATCH_FLOATING_TYPES
(
input
.
scalar_type
(),
"rms_norm_kernel"
,
[
&
]
{
vllm
::
rms_norm_kernel
<
scalar_t
><<<
grid
,
block
,
0
,
stream
>>>
(
out
.
data_ptr
<
scalar_t
>
(),
input
.
data_ptr
<
scalar_t
>
(),
weight
.
data_ptr
<
scalar_t
>
(),
epsilon
,
num_tokens
,
hidden_size
);
});
}
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), \
"fused_add_rms_norm_kernel", \
[&] { \
vllm::fused_add_rms_norm_kernel \
<scalar_t, width><<<grid, block, 0, stream>>>( \
input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), \
epsilon, \
num_tokens, \
hidden_size); \
});
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
float
epsilon
)
{
#define LAUNCH_FUSED_ADD_RMS_NORM(width) \
VLLM_DISPATCH_FLOATING_TYPES( \
input.scalar_type(), "fused_add_rms_norm_kernel", [&] { \
vllm::fused_add_rms_norm_kernel<scalar_t, width> \
<<<grid, block, 0, stream>>>(input.data_ptr<scalar_t>(), \
residual.data_ptr<scalar_t>(), \
weight.data_ptr<scalar_t>(), epsilon, \
num_tokens, hidden_size); \
});
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
// [..., hidden_size]
torch
::
Tensor
&
residual
,
// [..., hidden_size]
torch
::
Tensor
&
weight
,
// [hidden_size]
float
epsilon
)
{
int
hidden_size
=
input
.
size
(
-
1
);
int
num_tokens
=
input
.
numel
()
/
hidden_size
;
...
...
@@ -342,8 +342,8 @@ void fused_add_rms_norm(
auto
inp_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
input
.
data_ptr
());
auto
res_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
residual
.
data_ptr
());
auto
wt_ptr
=
reinterpret_cast
<
std
::
uintptr_t
>
(
weight
.
data_ptr
());
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
\
&&
wt_ptr
%
16
==
0
;
bool
ptrs_are_aligned
=
inp_ptr
%
16
==
0
&&
res_ptr
%
16
==
0
&&
wt_ptr
%
16
==
0
;
if
(
ptrs_are_aligned
&&
hidden_size
%
8
==
0
)
{
LAUNCH_FUSED_ADD_RMS_NORM
(
8
);
}
else
{
...
...
csrc/moe/moe_ops.cpp
View file @
5f6d10c1
...
...
@@ -3,5 +3,6 @@
#include <torch/extension.h>
PYBIND11_MODULE
(
TORCH_EXTENSION_NAME
,
m
)
{
m
.
def
(
"topk_softmax"
,
&
topk_softmax
,
"Apply topk softmax to the gating outputs."
);
m
.
def
(
"topk_softmax"
,
&
topk_softmax
,
"Apply topk softmax to the gating outputs."
);
}
csrc/moe/moe_ops.h
View file @
5f6d10c1
...
...
@@ -2,8 +2,6 @@
#include <torch/extension.h>
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
gating_output
);
void
topk_softmax
(
torch
::
Tensor
&
topk_weights
,
torch
::
Tensor
&
topk_indices
,
torch
::
Tensor
&
token_expert_indices
,
torch
::
Tensor
&
gating_output
);
csrc/moe_align_block_size_kernels.cu
View file @
5f6d10c1
...
...
@@ -7,119 +7,128 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"
#define CEILDIV(x,y) (((x) + (y) - 1) / (y))
#define CEILDIV(x,
y) (((x) + (y) - 1) / (y))
namespace
vllm
{
namespace
{
__device__
__forceinline__
int32_t
index
(
int32_t
total_col
,
int32_t
row
,
int32_t
col
)
{
// don't worry about overflow because num_experts is relatively small
return
row
*
total_col
+
col
;
}
__device__
__forceinline__
int32_t
index
(
int32_t
total_col
,
int32_t
row
,
int32_t
col
)
{
// don't worry about overflow because num_experts is relatively small
return
row
*
total_col
+
col
;
}
}
// namespace
template
<
typename
scalar_t
>
__global__
void
moe_align_block_size_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
sorted_token_ids
,
int32_t
*
expert_ids
,
int32_t
*
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
)
{
const
size_t
tokens_per_thread
=
CEILDIV
(
numel
,
blockDim
.
x
);
const
size_t
start_idx
=
threadIdx
.
x
*
tokens_per_thread
;
extern
__shared__
int32_t
shared_mem
[];
int32_t
*
tokens_cnts
=
shared_mem
;
// 2d tensor with shape (num_experts + 1, num_experts)
int32_t
*
cumsum
=
shared_mem
+
(
num_experts
+
1
)
*
num_experts
;
// 1d tensor with shape (num_experts + 1)
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
i
)]
=
0
;
}
/**
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are assigned
* to expert expert_index.
*/
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
topk_ids
[
i
])];
}
__syncthreads
();
// For each expert we accumulate the token counts from the different threads.
tokens_cnts
[
index
(
num_experts
,
0
,
threadIdx
.
x
)]
=
0
;
for
(
int
i
=
1
;
i
<=
blockDim
.
x
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
i
,
threadIdx
.
x
)]
+=
tokens_cnts
[
index
(
num_experts
,
i
-
1
,
threadIdx
.
x
)];
}
__syncthreads
();
// We accumulate the token counts of all experts in thread 0.
if
(
threadIdx
.
x
==
0
)
{
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
CEILDIV
(
tokens_cnts
[
index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
block_size
)
*
block_size
;
}
*
total_tokens_post_pad
=
cumsum
[
num_experts
];
}
__syncthreads
();
/**
* For each expert, each thread processes the tokens of the corresponding blocks
* and stores the corresponding expert_id for each block.
*/
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
i
+=
block_size
)
{
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
__global__
void
moe_align_block_size_kernel
(
scalar_t
*
__restrict__
topk_ids
,
int32_t
*
sorted_token_ids
,
int32_t
*
expert_ids
,
int32_t
*
total_tokens_post_pad
,
int32_t
num_experts
,
int32_t
block_size
,
size_t
numel
)
{
const
size_t
tokens_per_thread
=
CEILDIV
(
numel
,
blockDim
.
x
);
const
size_t
start_idx
=
threadIdx
.
x
*
tokens_per_thread
;
extern
__shared__
int32_t
shared_mem
[];
int32_t
*
tokens_cnts
=
shared_mem
;
// 2d tensor with shape (num_experts + 1, num_experts)
int32_t
*
cumsum
=
shared_mem
+
(
num_experts
+
1
)
*
num_experts
;
// 1d tensor with shape (num_experts + 1)
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
i
)]
=
0
;
}
/**
* In the first step we compute token_cnts[thread_index + 1][expert_index],
* which counts how many tokens in the token shard of thread_index are
* assigned to expert expert_index.
*/
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
+
1
,
topk_ids
[
i
])];
}
__syncthreads
();
// For each expert we accumulate the token counts from the different threads.
tokens_cnts
[
index
(
num_experts
,
0
,
threadIdx
.
x
)]
=
0
;
for
(
int
i
=
1
;
i
<=
blockDim
.
x
;
++
i
)
{
tokens_cnts
[
index
(
num_experts
,
i
,
threadIdx
.
x
)]
+=
tokens_cnts
[
index
(
num_experts
,
i
-
1
,
threadIdx
.
x
)];
}
__syncthreads
();
// We accumulate the token counts of all experts in thread 0.
if
(
threadIdx
.
x
==
0
)
{
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
CEILDIV
(
tokens_cnts
[
index
(
num_experts
,
blockDim
.
x
,
i
-
1
)],
block_size
)
*
block_size
;
}
/**
* Each thread processes a token shard, calculating the index of each token after
* sorting by expert number. Given the example topk_ids = [0,1,2,1,2,3,0,3,4] and
* block_size = 4, then the output would be [0, 6, *, *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *],
* where * represents a padding value(preset in python).
*/
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
int32_t
expert_id
=
topk_ids
[
i
];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and tokens_cnts[threadIdx.x][expert_id]
* stores the indices of the tokens processed by the expert with expert_id within
* the current thread's token shard.
*/
int32_t
rank_post_pad
=
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)]
+
cumsum
[
expert_id
];
sorted_token_ids
[
rank_post_pad
]
=
i
;
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)];
}
}
*
total_tokens_post_pad
=
cumsum
[
num_experts
];
}
__syncthreads
();
/**
* For each expert, each thread processes the tokens of the corresponding
* blocks and stores the corresponding expert_id for each block.
*/
for
(
int
i
=
cumsum
[
threadIdx
.
x
];
i
<
cumsum
[
threadIdx
.
x
+
1
];
i
+=
block_size
)
{
expert_ids
[
i
/
block_size
]
=
threadIdx
.
x
;
}
/**
* Each thread processes a token shard, calculating the index of each token
* after sorting by expert number. Given the example topk_ids =
* [0,1,2,1,2,3,0,3,4] and block_size = 4, then the output would be [0, 6, *,
* *, 1, 3, *, *, 2, 4, *, *, 5, 7, *, *, 8, *, *, *], where * represents a
* padding value(preset in python).
*/
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
int32_t
expert_id
=
topk_ids
[
i
];
/** The cumsum[expert_id] stores the starting index of the tokens that the
* expert with expert_id needs to process, and
* tokens_cnts[threadIdx.x][expert_id] stores the indices of the tokens
* processed by the expert with expert_id within the current thread's token
* shard.
*/
int32_t
rank_post_pad
=
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)]
+
cumsum
[
expert_id
];
sorted_token_ids
[
rank_post_pad
]
=
i
;
++
tokens_cnts
[
index
(
num_experts
,
threadIdx
.
x
,
expert_id
)];
}
}
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
num_experts
,
int
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
)
{
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `tokens_cnts` and `cumsum` tensors
const
int32_t
shared_mem
=
((
num_experts
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
sizeof
(
int32_t
);
}
// namespace vllm
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
num_experts
,
int
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
)
{
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_INTEGRAL_TYPES
(
topk_ids
.
scalar_type
(),
"moe_align_block_size_kernel"
,
[
&
]
{
// calc needed amount of shared mem for `tokens_cnts` and `cumsum`
// tensors
const
int32_t
shared_mem
=
((
num_experts
+
1
)
*
num_experts
+
(
num_experts
+
1
))
*
sizeof
(
int32_t
);
// set dynamic shared mem
auto
kernel
=
vllm
::
moe_align_block_size_kernel
<
scalar_t
>
;
AT_CUDA_CHECK
(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize
(
(
void
*
)
kernel
,
shared_mem
));
AT_CUDA_CHECK
(
VLLM_DevFuncAttribute_SET_MaxDynamicSharedMemorySize
(
(
void
*
)
kernel
,
shared_mem
));
kernel
<<<
1
,
num_experts
,
shared_mem
,
stream
>>>
(
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
data_ptr
<
scalar_t
>
(),
sorted_token_ids
.
data_ptr
<
int32_t
>
(),
experts_ids
.
data_ptr
<
int32_t
>
(),
num_tokens_post_pad
.
data_ptr
<
int32_t
>
(),
num_experts
,
block_size
,
topk_ids
.
numel
());
});
});
}
csrc/ops.h
View file @
5f6d10c1
...
...
@@ -2,224 +2,136 @@
#include <torch/extension.h>
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
block_size
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
void
paged_attention_v2
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
block_size
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
float
epsilon
);
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
float
epsilon
);
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_tanh_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_new
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_fast
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
paged_attention_v1
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
block_size
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
void
paged_attention_v2
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
exp_sums
,
torch
::
Tensor
&
max_logits
,
torch
::
Tensor
&
tmp_out
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key_cache
,
torch
::
Tensor
&
value_cache
,
int
num_kv_heads
,
float
scale
,
torch
::
Tensor
&
block_tables
,
torch
::
Tensor
&
seq_lens
,
int
block_size
,
int
max_seq_len
,
const
c10
::
optional
<
torch
::
Tensor
>&
alibi_slopes
,
const
std
::
string
&
kv_cache_dtype
,
float
kv_scale
);
void
rms_norm
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
weight
,
float
epsilon
);
void
fused_add_rms_norm
(
torch
::
Tensor
&
input
,
torch
::
Tensor
&
residual
,
torch
::
Tensor
&
weight
,
float
epsilon
);
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
);
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
torch
::
Tensor
&
query
,
torch
::
Tensor
&
key
,
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
bool
is_neox
,
int
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
);
void
silu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_tanh_and_mul
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_new
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
void
gelu_fast
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
);
#ifndef USE_ROCM
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
torch
::
Tensor
&
codebook_partition_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
);
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codebook_partition_sizes
);
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
);
torch
::
Tensor
awq_dequantize
(
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
,
int
thx
,
int
thy
);
torch
::
Tensor
marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_24_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
);
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
int
cutlass_scaled_mm_dq
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
torch
::
Tensor
&
codebook_partition_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
);
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codebook_partition_sizes
);
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
);
torch
::
Tensor
awq_dequantize
(
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
,
int
thx
,
int
thy
);
torch
::
Tensor
marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_24_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_meta
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
);
torch
::
Tensor
gptq_marlin_gemm
(
torch
::
Tensor
&
a
,
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
b_scales
,
torch
::
Tensor
&
g_idx
,
torch
::
Tensor
&
perm
,
torch
::
Tensor
&
workspace
,
int64_t
num_bits
,
int64_t
size_m
,
int64_t
size_n
,
int64_t
size_k
,
bool
is_k_full
);
torch
::
Tensor
gptq_marlin_repack
(
torch
::
Tensor
&
b_q_weight
,
torch
::
Tensor
&
perm
,
int64_t
size_k
,
int64_t
size_n
,
int64_t
num_bits
);
int
cutlass_scaled_mm_dq
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
#endif
void
squeezellm_gemm
(
torch
::
Tensor
vec
,
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
torch
::
Tensor
lookup_table
);
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
bool
use_exllama
,
int
bit
);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int
bit
);
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
scale
);
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
scale
);
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
num_experts
,
int
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
void
squeezellm_gemm
(
torch
::
Tensor
vec
,
torch
::
Tensor
mat
,
torch
::
Tensor
mul
,
torch
::
Tensor
lookup_table
);
torch
::
Tensor
gptq_gemm
(
torch
::
Tensor
a
,
torch
::
Tensor
b_q_weight
,
torch
::
Tensor
b_gptq_qzeros
,
torch
::
Tensor
b_gptq_scales
,
torch
::
Tensor
b_g_idx
,
bool
use_exllama
,
int
bit
);
void
gptq_shuffle
(
torch
::
Tensor
q_weight
,
torch
::
Tensor
q_perm
,
int
bit
);
void
static_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
scale
);
void
dynamic_scaled_fp8_quant
(
torch
::
Tensor
&
out
,
torch
::
Tensor
&
input
,
torch
::
Tensor
&
scale
);
void
moe_align_block_size
(
torch
::
Tensor
topk_ids
,
int
num_experts
,
int
block_size
,
torch
::
Tensor
sorted_token_ids
,
torch
::
Tensor
experts_ids
,
torch
::
Tensor
num_tokens_post_pad
);
#ifndef USE_ROCM
using
fptr_t
=
uint64_t
;
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
,
int
rank
,
bool
full_nvlink
);
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int
max_size
,
int
world_size
,
fptr_t
init_custom_ar
(
torch
::
Tensor
&
meta
,
torch
::
Tensor
&
rank_data
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
,
int
rank
,
bool
full_nvlink
);
bool
should_custom_ar
(
torch
::
Tensor
&
inp
,
int
max_size
,
int
world_size
,
bool
full_nvlink
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
);
void
all_reduce_reg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
out
);
void
all_reduce_unreg
(
fptr_t
_fa
,
torch
::
Tensor
&
inp
,
torch
::
Tensor
&
reg_buffer
,
torch
::
Tensor
&
out
);
void
dispose
(
fptr_t
_fa
);
int
meta_size
();
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
int64_t
>
&
offsets
);
std
::
pair
<
std
::
vector
<
uint8_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
);
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>
&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>
&
offsets
);
void
register_buffer
(
fptr_t
_fa
,
torch
::
Tensor
&
t
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
int64_t
>&
offsets
);
std
::
pair
<
std
::
vector
<
uint8_t
>
,
std
::
vector
<
int64_t
>>
get_graph_buffer_ipc_meta
(
fptr_t
_fa
);
void
register_graph_buffers
(
fptr_t
_fa
,
const
std
::
vector
<
std
::
string
>&
handles
,
const
std
::
vector
<
std
::
vector
<
int64_t
>>&
offsets
);
#endif
csrc/pos_encoding_kernels.cu
View file @
5f6d10c1
...
...
@@ -7,14 +7,10 @@
namespace
vllm
{
template
<
typename
scalar_t
,
bool
IS_NEOX
>
template
<
typename
scalar_t
,
bool
IS_NEOX
>
inline
__device__
void
apply_token_rotary_embedding
(
scalar_t
*
__restrict__
arr
,
const
scalar_t
*
__restrict__
cos_ptr
,
const
scalar_t
*
__restrict__
sin_ptr
,
int
rot_offset
,
int
embed_dim
)
{
scalar_t
*
__restrict__
arr
,
const
scalar_t
*
__restrict__
cos_ptr
,
const
scalar_t
*
__restrict__
sin_ptr
,
int
rot_offset
,
int
embed_dim
)
{
int
x_index
,
y_index
;
scalar_t
cos
,
sin
;
if
(
IS_NEOX
)
{
...
...
@@ -37,19 +33,17 @@ inline __device__ void apply_token_rotary_embedding(
arr
[
y_index
]
=
y
*
cos
+
x
*
sin
;
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
template
<
typename
scalar_t
,
bool
IS_NEOX
>
inline
__device__
void
apply_rotary_embedding
(
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const
scalar_t
*
cache_ptr
,
const
int
head_size
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
rot_dim
,
const
int
token_idx
,
const
int64_t
query_stride
,
const
int64_t
key_stride
)
{
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
cache_ptr
,
const
int
head_size
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
rot_dim
,
const
int
token_idx
,
const
int64_t
query_stride
,
const
int64_t
key_stride
)
{
const
int
embed_dim
=
rot_dim
/
2
;
const
scalar_t
*
cos_ptr
=
cache_ptr
;
const
scalar_t
*
sin_ptr
=
cache_ptr
+
embed_dim
;
...
...
@@ -59,8 +53,8 @@ inline __device__ void apply_rotary_embedding(
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
query_stride
+
head_idx
*
head_size
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
const
int
nk
=
num_kv_heads
*
embed_dim
;
...
...
@@ -68,62 +62,74 @@ inline __device__ void apply_rotary_embedding(
const
int
head_idx
=
i
/
embed_dim
;
const
int64_t
token_head
=
token_idx
*
key_stride
+
head_idx
*
head_size
;
const
int
rot_offset
=
i
%
embed_dim
;
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
apply_token_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
key
+
token_head
,
cos_ptr
,
sin_ptr
,
rot_offset
,
embed_dim
);
}
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
template
<
typename
scalar_t
,
bool
IS_NEOX
>
__global__
void
rotary_embedding_kernel
(
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or [num_tokens]
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim // 2]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or
// [num_tokens]
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
// 2]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
pos
*
rot_dim
;
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
}
template
<
typename
scalar_t
,
bool
IS_NEOX
>
template
<
typename
scalar_t
,
bool
IS_NEOX
>
__global__
void
batched_rotary_embedding_kernel
(
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or [num_tokens]
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads, head_size] or [num_tokens, num_heads, head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads, head_size] or [num_tokens, num_kv_heads, head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim // 2]
const
int64_t
*
__restrict__
cos_sin_cache_offsets
,
// [batch_size, seq_len] or [num_tokens]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
const
int64_t
*
__restrict__
positions
,
// [batch_size, seq_len] or
// [num_tokens]
scalar_t
*
__restrict__
query
,
// [batch_size, seq_len, num_heads,
// head_size] or [num_tokens, num_heads,
// head_size]
scalar_t
*
__restrict__
key
,
// [batch_size, seq_len, num_kv_heads,
// head_size] or [num_tokens, num_kv_heads,
// head_size]
const
scalar_t
*
__restrict__
cos_sin_cache
,
// [max_position, 2, rot_dim //
// 2]
const
int64_t
*
__restrict__
cos_sin_cache_offsets
,
// [batch_size, seq_len]
// or [num_tokens]
const
int
rot_dim
,
const
int64_t
query_stride
,
const
int64_t
key_stride
,
const
int
num_heads
,
const
int
num_kv_heads
,
const
int
head_size
)
{
// Each thread block is responsible for one token.
const
int
token_idx
=
blockIdx
.
x
;
int64_t
pos
=
positions
[
token_idx
];
int64_t
cos_sin_cache_offset
=
cos_sin_cache_offsets
[
token_idx
];
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
(
cos_sin_cache_offset
+
pos
)
*
rot_dim
;
const
scalar_t
*
cache_ptr
=
cos_sin_cache
+
(
cos_sin_cache_offset
+
pos
)
*
rot_dim
;
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
apply_rotary_embedding
<
scalar_t
,
IS_NEOX
>
(
query
,
key
,
cache_ptr
,
head_size
,
num_heads
,
num_kv_heads
,
rot_dim
,
token_idx
,
query_stride
,
key_stride
);
}
}
// namespace vllm
}
// namespace vllm
void
rotary_embedding
(
torch
::
Tensor
&
positions
,
// [batch_size, seq_len] or [num_tokens]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
)
{
torch
::
Tensor
&
positions
,
// [batch_size, seq_len] or [num_tokens]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
)
{
int64_t
num_tokens
=
query
.
numel
()
/
query
.
size
(
-
1
);
int
rot_dim
=
cos_sin_cache
.
size
(
1
);
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
...
...
@@ -135,36 +141,21 @@ void rotary_embedding(
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"rotary_embedding"
,
[
&
]
{
if
(
is_neox
)
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"rotary_embedding"
,
[
&
]
{
if
(
is_neox
)
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
vllm
::
rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
}
/*
...
...
@@ -172,14 +163,15 @@ Batched version of rotary embedding, pack multiple LoRAs together
and process in batched manner.
*/
void
batched_rotary_embedding
(
torch
::
Tensor
&
positions
,
// [batch_size, seq_len] or [num_tokens]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or [num_tokens, num_heads * head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or [num_tokens, num_kv_heads * head_size]
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
,
int
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
// [num_tokens]
torch
::
Tensor
&
positions
,
// [batch_size, seq_len] or [num_tokens]
torch
::
Tensor
&
query
,
// [batch_size, seq_len, num_heads * head_size] or
// [num_tokens, num_heads * head_size]
torch
::
Tensor
&
key
,
// [batch_size, seq_len, num_kv_heads * head_size] or
// [num_tokens, num_kv_heads * head_size]
int
head_size
,
torch
::
Tensor
&
cos_sin_cache
,
// [max_position, rot_dim]
bool
is_neox
,
int
rot_dim
,
torch
::
Tensor
&
cos_sin_cache_offsets
// [num_tokens]
)
{
int64_t
num_tokens
=
cos_sin_cache_offsets
.
size
(
0
);
int
num_heads
=
query
.
size
(
-
1
)
/
head_size
;
...
...
@@ -191,36 +183,21 @@ void batched_rotary_embedding(
dim3
block
(
std
::
min
(
num_heads
*
rot_dim
/
2
,
512
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
query
));
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"rotary_embedding"
,
[
&
]
{
if
(
is_neox
)
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
true
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
><<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
VLLM_DISPATCH_FLOATING_TYPES
(
query
.
scalar_type
(),
"rotary_embedding"
,
[
&
]
{
if
(
is_neox
)
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
true
>
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
else
{
vllm
::
batched_rotary_embedding_kernel
<
scalar_t
,
false
>
<<<
grid
,
block
,
0
,
stream
>>>
(
positions
.
data_ptr
<
int64_t
>
(),
query
.
data_ptr
<
scalar_t
>
(),
key
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache
.
data_ptr
<
scalar_t
>
(),
cos_sin_cache_offsets
.
data_ptr
<
int64_t
>
(),
rot_dim
,
query_stride
,
key_stride
,
num_heads
,
num_kv_heads
,
head_size
);
}
});
}
csrc/pybind.cpp
View file @
5f6d10c1
...
...
@@ -8,116 +8,87 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
pybind11
::
module
ops
=
m
.
def_submodule
(
"ops"
,
"vLLM custom operators"
);
// Attention ops
ops
.
def
(
"paged_attention_v1"
,
&
paged_attention_v1
,
"Compute the attention between an input query and the cached keys/values using PagedAttention."
);
ops
.
def
(
"paged_attention_v2"
,
&
paged_attention_v2
,
"PagedAttention V2."
);
ops
.
def
(
"paged_attention_v1"
,
&
paged_attention_v1
,
"Compute the attention between an input query and the cached "
"keys/values using PagedAttention."
);
ops
.
def
(
"paged_attention_v2"
,
&
paged_attention_v2
,
"PagedAttention V2."
);
// Activation ops
ops
.
def
(
"silu_and_mul"
,
&
silu_and_mul
,
"Activation function used in SwiGLU."
);
ops
.
def
(
"gelu_and_mul"
,
&
gelu_and_mul
,
"Activation function used in GeGLU with `none` approximation."
);
ops
.
def
(
"gelu_tanh_and_mul"
,
&
gelu_tanh_and_mul
,
"Activation function used in GeGLU with `tanh` approximation."
);
ops
.
def
(
"gelu_new"
,
&
gelu_new
,
"GELU implementation used in GPT-2."
);
ops
.
def
(
"gelu_fast"
,
&
gelu_fast
,
"Approximate GELU implementation."
);
ops
.
def
(
"silu_and_mul"
,
&
silu_and_mul
,
"Activation function used in SwiGLU."
);
ops
.
def
(
"gelu_and_mul"
,
&
gelu_and_mul
,
"Activation function used in GeGLU with `none` approximation."
);
ops
.
def
(
"gelu_tanh_and_mul"
,
&
gelu_tanh_and_mul
,
"Activation function used in GeGLU with `tanh` approximation."
);
ops
.
def
(
"gelu_new"
,
&
gelu_new
,
"GELU implementation used in GPT-2."
);
ops
.
def
(
"gelu_fast"
,
&
gelu_fast
,
"Approximate GELU implementation."
);
// Layernorm
ops
.
def
(
"rms_norm"
,
&
rms_norm
,
"Apply Root Mean Square (RMS) Normalization to the input tensor."
);
ops
.
def
(
"rms_norm"
,
&
rms_norm
,
"Apply Root Mean Square (RMS) Normalization to the input tensor."
);
ops
.
def
(
"fused_add_rms_norm"
,
&
fused_add_rms_norm
,
"In-place fused Add and RMS Normalization"
);
ops
.
def
(
"fused_add_rms_norm"
,
&
fused_add_rms_norm
,
"In-place fused Add and RMS Normalization"
);
// Rotary embedding
ops
.
def
(
"rotary_embedding"
,
&
rotary_embedding
,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key"
);
ops
.
def
(
"rotary_embedding"
,
&
rotary_embedding
,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key"
);
ops
.
def
(
"batched_rotary_embedding"
,
&
batched_rotary_embedding
,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key (supports multiple loras)"
);
ops
.
def
(
"batched_rotary_embedding"
,
&
batched_rotary_embedding
,
"Apply GPT-NeoX or GPT-J style rotary embedding to query and key "
"(supports multiple loras)"
);
// Quantization ops
#ifndef USE_ROCM
ops
.
def
(
"aqlm_gemm"
,
&
aqlm_gemm
,
"Quantized GEMM for AQLM"
);
ops
.
def
(
"aqlm_dequant"
,
&
aqlm_dequant
,
"Decompression method for AQLM"
);
ops
.
def
(
"awq_gemm"
,
&
awq_gemm
,
"Quantized GEMM for AWQ"
);
ops
.
def
(
"marlin_gemm"
,
&
marlin_gemm
,
"Marlin (Dense) Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_24_gemm"
,
&
gptq_marlin_24_gemm
,
"Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
,
"gptq_marlin Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
,
"gptq_marlin repack from GPTQ"
);
ops
.
def
(
"marlin_gemm"
,
&
marlin_gemm
,
"Marlin (Dense) Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_24_gemm"
,
&
gptq_marlin_24_gemm
,
"Marlin_24 (Sparse) Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_gemm"
,
&
gptq_marlin_gemm
,
"gptq_marlin Optimized Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_marlin_repack"
,
&
gptq_marlin_repack
,
"gptq_marlin repack from GPTQ"
);
ops
.
def
(
"awq_dequantize"
,
&
awq_dequantize
,
"Dequantization for AWQ"
);
ops
.
def
(
"cutlass_scaled_mm_dq"
,
&
cutlass_scaled_mm_dq
,
"CUTLASS w8a8 GEMM, supporting symmetric per-tensor or per-row/column quantization."
);
ops
.
def
(
"cutlass_scaled_mm_dq"
,
&
cutlass_scaled_mm_dq
,
"CUTLASS w8a8 GEMM, supporting symmetric per-tensor or "
"per-row/column quantization."
);
#endif
ops
.
def
(
"gptq_gemm"
,
&
gptq_gemm
,
"Quantized GEMM for GPTQ"
);
ops
.
def
(
"gptq_shuffle"
,
&
gptq_shuffle
,
"Post processing for GPTQ"
);
ops
.
def
(
"squeezellm_gemm"
,
&
squeezellm_gemm
,
"Quantized GEMM for SqueezeLLM"
);
ops
.
def
(
"static_scaled_fp8_quant"
,
&
static_scaled_fp8_quant
,
"Compute FP8 quantized tensor for given scaling factor"
);
ops
.
def
(
"dynamic_scaled_fp8_quant"
,
&
dynamic_scaled_fp8_quant
,
"Compute FP8 quantized tensor and scaling factor"
);
ops
.
def
(
"moe_align_block_size"
,
&
moe_align_block_size
,
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size."
);
ops
.
def
(
"static_scaled_fp8_quant"
,
&
static_scaled_fp8_quant
,
"Compute FP8 quantized tensor for given scaling factor"
);
ops
.
def
(
"dynamic_scaled_fp8_quant"
,
&
dynamic_scaled_fp8_quant
,
"Compute FP8 quantized tensor and scaling factor"
);
ops
.
def
(
"moe_align_block_size"
,
&
moe_align_block_size
,
"Aligning the number of tokens to be processed by each expert such "
"that it is divisible by the block size."
);
// Cache ops
pybind11
::
module
cache_ops
=
m
.
def_submodule
(
"cache_ops"
,
"vLLM cache ops"
);
cache_ops
.
def
(
"swap_blocks"
,
&
swap_blocks
,
"Swap in (out) the cache blocks from src to dst"
);
cache_ops
.
def
(
"copy_blocks"
,
&
copy_blocks
,
"Copy the cache blocks from src to dst"
);
cache_ops
.
def
(
"reshape_and_cache"
,
&
reshape_and_cache
,
"Reshape the key and value tensors and cache them"
);
cache_ops
.
def
(
"reshape_and_cache_flash"
,
&
reshape_and_cache_flash
,
"Reshape the key and value tensors and cache them"
);
cache_ops
.
def
(
"convert_fp8"
,
&
convert_fp8
,
"Convert the key and value cache to fp8 data type"
);
cache_ops
.
def
(
"swap_blocks"
,
&
swap_blocks
,
"Swap in (out) the cache blocks from src to dst"
);
cache_ops
.
def
(
"copy_blocks"
,
&
copy_blocks
,
"Copy the cache blocks from src to dst"
);
cache_ops
.
def
(
"reshape_and_cache"
,
&
reshape_and_cache
,
"Reshape the key and value tensors and cache them"
);
cache_ops
.
def
(
"reshape_and_cache_flash"
,
&
reshape_and_cache_flash
,
"Reshape the key and value tensors and cache them"
);
cache_ops
.
def
(
"convert_fp8"
,
&
convert_fp8
,
"Convert the key and value cache to fp8 data type"
);
// Cuda utils
pybind11
::
module
cuda_utils
=
m
.
def_submodule
(
"cuda_utils"
,
"vLLM cuda utils"
);
cuda_utils
.
def
(
"get_device_attribute"
,
&
get_device_attribute
,
"Gets the specified device attribute."
);
pybind11
::
module
cuda_utils
=
m
.
def_submodule
(
"cuda_utils"
,
"vLLM cuda utils"
);
cuda_utils
.
def
(
"get_device_attribute"
,
&
get_device_attribute
,
"Gets the specified device attribute."
);
cuda_utils
.
def
(
"get_max_shared_memory_per_block_device_attribute"
,
&
get_max_shared_memory_per_block_device_attribute
,
"Gets the maximum shared memory per block device attribute."
);
cuda_utils
.
def
(
"get_max_shared_memory_per_block_device_attribute"
,
&
get_max_shared_memory_per_block_device_attribute
,
"Gets the maximum shared memory per block device attribute."
);
#ifndef USE_ROCM
// Custom all-reduce kernels
...
...
@@ -134,5 +105,4 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
custom_ar
.
def
(
"register_graph_buffers"
,
&
register_graph_buffers
,
"register_graph_buffers"
);
#endif
}
csrc/quantization/aqlm/gemm_kernels.cu
View file @
5f6d10c1
...
...
@@ -25,32 +25,28 @@
#include <iostream>
#include <cstdlib>
namespace
vllm
{
namespace
aqlm
{
__global__
void
Code1x16MatVec
(
const
int4
*
__restrict__
A
,
const
int4
*
__restrict__
B
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
const
int
prob_m
,
const
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at most 3 long.
const
int
codebook_stride
// as int4.
const
int4
*
__restrict__
A
,
const
int4
*
__restrict__
B
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
const
int
prob_m
,
const
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each
// codebook, at most 3 long.
const
int
codebook_stride
// as int4.
)
{
int
a_gl_stride
=
prob_k
/
8
/
8
;
int
a_gl_rd
=
(
blockDim
.
x
/
32
)
*
blockIdx
.
x
+
(
threadIdx
.
x
/
32
);
bool
pred
=
a_gl_rd
<
prob_m
;
if
(
pred
)
{
//
advance to the correct codebook, this easy because we only multiply one
column of the codebook.
if
(
pred
)
{
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto
codebook_size
=
&
codebook_a_sizes
.
x
;
while
(
a_gl_rd
>=
*
codebook_size
)
{
codebook
+=
codebook_stride
;
++
codebook_size
;
while
(
a_gl_rd
>=
*
codebook_size
)
{
codebook
+=
codebook_stride
;
++
codebook_size
;
}
}
...
...
@@ -67,8 +63,7 @@ __global__ void Code1x16MatVec(
// We pad shared memory to avoid bank conflicts during reads
__syncthreads
();
for
(
int
i
=
threadIdx
.
x
;
i
<
32
*
8
;
i
+=
blockDim
.
x
)
{
if
(
b_gl_rd
+
i
<
prob_k
/
8
)
sh_b
[
9
*
(
i
/
8
)
+
i
%
8
]
=
B
[
b_gl_rd
+
i
];
if
(
b_gl_rd
+
i
<
prob_k
/
8
)
sh_b
[
9
*
(
i
/
8
)
+
i
%
8
]
=
B
[
b_gl_rd
+
i
];
}
__syncthreads
();
b_gl_rd
+=
32
*
8
;
...
...
@@ -76,22 +71,19 @@ __global__ void Code1x16MatVec(
int
b_sh_rd
=
9
*
(
threadIdx
.
x
%
32
);
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
const
uint16_t
*
enc
=
reinterpret_cast
<
const
uint16_t
*>
(
&
A
[
a_gl_rd
]);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
uint32_t
dec
[
4
];
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
// actually help us; this brings > 2x speedup.
asm
volatile
(
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
:
"=r"
(
dec
[
0
]),
"=r"
(
dec
[
1
]),
"=r"
(
dec
[
2
]),
"=r"
(
dec
[
3
])
:
"l"
((
void
*
)
&
codebook
[
enc
[
i
]])
);
// We bypass the L1 cache to avoid massive amounts of memory streaming
// that doesn't actually help us; this brings > 2x speedup.
asm
volatile
(
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
:
"=r"
(
dec
[
0
]),
"=r"
(
dec
[
1
]),
"=r"
(
dec
[
2
]),
"=r"
(
dec
[
3
])
:
"l"
((
void
*
)
&
codebook
[
enc
[
i
]]));
half2
*
a
=
reinterpret_cast
<
half2
*>
(
&
dec
);
half2
*
b
=
reinterpret_cast
<
half2
*>
(
&
sh_b
[
b_sh_rd
]);
half2
res2
=
{};
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
res2
=
__hfma2
(
a
[
j
],
b
[
j
],
res2
);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
res2
=
__hfma2
(
a
[
j
],
b
[
j
],
res2
);
res
+=
__half2float
(
res2
.
x
)
+
__half2float
(
res2
.
y
);
b_sh_rd
++
;
}
...
...
@@ -100,37 +92,33 @@ __global__ void Code1x16MatVec(
}
if
(
pred
)
{
#pragma unroll
for
(
int
i
=
16
;
i
>
0
;
i
/=
2
)
res
+=
__shfl_down_sync
(
0xffffffff
,
res
,
i
);
#pragma unroll
for
(
int
i
=
16
;
i
>
0
;
i
/=
2
)
res
+=
__shfl_down_sync
(
0xffffffff
,
res
,
i
);
if
(
threadIdx
.
x
%
32
==
0
)
reinterpret_cast
<
__half
*>
(
C
)[
c_gl_wr
]
=
__float2half
(
res
);
}
}
__global__
void
Code2x8MatVec
(
const
int4
*
__restrict__
A
,
const
int4
*
__restrict__
B
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at most 3 long.
const
int
codebook_stride
// as int4.
const
int4
*
__restrict__
A
,
const
int4
*
__restrict__
B
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each
// codebook, at most 3 long.
const
int
codebook_stride
// as int4.
)
{
int
a_gl_stride
=
prob_k
/
8
/
8
;
int
a_gl_rd
=
(
blockDim
.
x
/
32
)
*
blockIdx
.
x
+
(
threadIdx
.
x
/
32
);
bool
pred
=
a_gl_rd
<
prob_m
;
if
(
pred
)
{
//
advance to the correct codebook, this easy because we only multiply one
column of the codebook.
if
(
pred
)
{
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto
codebook_size
=
&
codebook_a_sizes
.
x
;
while
(
a_gl_rd
>=
*
codebook_size
)
{
codebook
+=
codebook_stride
;
++
codebook_size
;
while
(
a_gl_rd
>=
*
codebook_size
)
{
codebook
+=
codebook_stride
;
++
codebook_size
;
}
}
...
...
@@ -148,9 +136,8 @@ __global__ void Code2x8MatVec(
for
(
int
i
=
threadIdx
.
x
;
i
<
2
*
256
;
i
+=
blockDim
.
x
)
{
int4
dec
=
codebook
[
i
];
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
sh_code
[
8
*
i
+
(
j
+
lane
)
%
8
]
=
dec
;
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
sh_code
[
8
*
i
+
(
j
+
lane
)
%
8
]
=
dec
;
}
__syncthreads
();
...
...
@@ -161,8 +148,7 @@ __global__ void Code2x8MatVec(
// We pad shared memory to avoid bank conflicts during reads
__syncthreads
();
for
(
int
i
=
threadIdx
.
x
;
i
<
32
*
8
;
i
+=
blockDim
.
x
)
{
if
(
b_gl_rd
+
i
<
prob_k
/
8
)
sh_b
[
9
*
(
i
/
8
)
+
i
%
8
]
=
B
[
b_gl_rd
+
i
];
if
(
b_gl_rd
+
i
<
prob_k
/
8
)
sh_b
[
9
*
(
i
/
8
)
+
i
%
8
]
=
B
[
b_gl_rd
+
i
];
}
__syncthreads
();
b_gl_rd
+=
32
*
8
;
...
...
@@ -170,13 +156,15 @@ __global__ void Code2x8MatVec(
int
b_sh_rd
=
9
*
(
threadIdx
.
x
%
32
);
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
const
uint8_t
*
enc
=
reinterpret_cast
<
const
uint8_t
*>
(
&
A
[
a_gl_rd
]);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
half2
*
a0
=
reinterpret_cast
<
half2
*>
(
&
sh_code0
[
8
*
enc
[
2
*
i
+
0
]
+
lane
]);
half2
*
a1
=
reinterpret_cast
<
half2
*>
(
&
sh_code1
[
8
*
enc
[
2
*
i
+
1
]
+
lane
]);
half2
*
b
=
reinterpret_cast
<
half2
*>
(
&
sh_b
[
b_sh_rd
]);
half2
*
a0
=
reinterpret_cast
<
half2
*>
(
&
sh_code0
[
8
*
enc
[
2
*
i
+
0
]
+
lane
]);
half2
*
a1
=
reinterpret_cast
<
half2
*>
(
&
sh_code1
[
8
*
enc
[
2
*
i
+
1
]
+
lane
]);
half2
*
b
=
reinterpret_cast
<
half2
*>
(
&
sh_b
[
b_sh_rd
]);
half2
res2
=
{};
#pragma unroll
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
res2
=
__hfma2
(
__hadd2
(
a0
[
j
],
a1
[
j
]),
b
[
j
],
res2
);
res
+=
__half2float
(
res2
.
x
)
+
__half2float
(
res2
.
y
);
...
...
@@ -187,36 +175,31 @@ __global__ void Code2x8MatVec(
}
if
(
pred
)
{
#pragma unroll
for
(
int
i
=
16
;
i
>
0
;
i
/=
2
)
res
+=
__shfl_down_sync
(
0xffffffff
,
res
,
i
);
#pragma unroll
for
(
int
i
=
16
;
i
>
0
;
i
/=
2
)
res
+=
__shfl_down_sync
(
0xffffffff
,
res
,
i
);
if
(
threadIdx
.
x
%
32
==
0
)
reinterpret_cast
<
__half
*>
(
C
)[
c_gl_wr
]
=
__float2half
(
res
);
}
}
__global__
void
Code1x16Dequant
(
const
int4
*
__restrict__
A
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at most 3 long, sums to m.
const
int
codebook_stride
// as int4
const
int4
*
__restrict__
A
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each
// codebook, at most 3 long, sums to m.
const
int
codebook_stride
// as int4
)
{
int
a_gl_stride
=
prob_k
/
8
/
8
;
int
a_gl_rd
=
(
blockDim
.
x
/
32
)
*
blockIdx
.
x
+
(
threadIdx
.
x
/
32
);
bool
pred
=
a_gl_rd
<
prob_m
;
if
(
pred
)
{
//
advance to the correct codebook, this easy because we only multiply one
column of the codebook.
if
(
pred
)
{
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto
codebook_size
=
&
codebook_a_sizes
.
x
;
while
(
a_gl_rd
>=
*
codebook_size
)
{
codebook
+=
codebook_stride
;
++
codebook_size
;
while
(
a_gl_rd
>=
*
codebook_size
)
{
codebook
+=
codebook_stride
;
++
codebook_size
;
}
}
...
...
@@ -231,17 +214,15 @@ __global__ void Code1x16Dequant(
while
(
iters
--
)
{
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
const
uint16_t
*
enc
=
reinterpret_cast
<
const
uint16_t
*>
(
&
A
[
a_gl_rd
]);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
int4
chunk
;
auto
dec
=
reinterpret_cast
<
uint32_t
*>
(
&
chunk
);
// We bypass the L1 cache to avoid massive amounts of memory streaming that doesn't
// actually help us; this brings > 2x speedup.
asm
volatile
(
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
:
"=r"
(
dec
[
0
]),
"=r"
(
dec
[
1
]),
"=r"
(
dec
[
2
]),
"=r"
(
dec
[
3
])
:
"l"
((
void
*
)
&
codebook
[
enc
[
i
]])
);
// We bypass the L1 cache to avoid massive amounts of memory streaming
// that doesn't actually help us; this brings > 2x speedup.
asm
volatile
(
"ld.cg.global.v4.u32 {%0, %1, %2, %3}, [%4];"
:
"=r"
(
dec
[
0
]),
"=r"
(
dec
[
1
]),
"=r"
(
dec
[
2
]),
"=r"
(
dec
[
3
])
:
"l"
((
void
*
)
&
codebook
[
enc
[
i
]]));
C
[
a_gl_rd
*
8
+
i
]
=
chunk
;
}
...
...
@@ -250,28 +231,25 @@ __global__ void Code1x16Dequant(
}
}
__global__
void
Code2x8Dequant
(
const
int4
*
__restrict__
A
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
const
int
codebook_stride
// as int4
const
int4
*
__restrict__
A
,
int4
*
__restrict__
C
,
const
int4
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at
// most 3 long, corresponds to cols.
const
int
codebook_stride
// as int4
)
{
int
a_gl_stride
=
prob_k
/
8
/
8
;
int
a_gl_rd
=
(
blockDim
.
x
/
32
)
*
blockIdx
.
x
+
(
threadIdx
.
x
/
32
);
bool
pred
=
a_gl_rd
<
prob_m
;
if
(
pred
)
{
//
advance to the correct codebook, this easy because we only multiply one
column of the codebook.
if
(
pred
)
{
// advance to the correct codebook, this easy because we only multiply one
// column of the codebook.
auto
codebook_size
=
&
codebook_a_sizes
.
x
;
while
(
a_gl_rd
>=
*
codebook_size
)
{
codebook
+=
codebook_stride
;
++
codebook_size
;
while
(
a_gl_rd
>=
*
codebook_size
)
{
codebook
+=
codebook_stride
;
++
codebook_size
;
}
}
...
...
@@ -290,9 +268,8 @@ __global__ void Code2x8Dequant(
for
(
int
i
=
threadIdx
.
x
;
i
<
2
*
256
;
i
+=
blockDim
.
x
)
{
int4
dec
=
codebook
[
i
];
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
sh_code
[
8
*
i
+
(
j
+
lane
)
%
8
]
=
dec
;
#pragma unroll
for
(
int
j
=
0
;
j
<
8
;
j
++
)
sh_code
[
8
*
i
+
(
j
+
lane
)
%
8
]
=
dec
;
}
__syncthreads
();
...
...
@@ -302,12 +279,14 @@ __global__ void Code2x8Dequant(
while
(
iters
--
)
{
if
(
pred
&&
a_gl_rd
<
a_gl_end
)
{
const
uint8_t
*
enc
=
reinterpret_cast
<
const
uint8_t
*>
(
&
A
[
a_gl_rd
]);
#pragma unroll
#pragma unroll
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
int4
chunk
;
half2
*
a0
=
reinterpret_cast
<
half2
*>
(
&
sh_code0
[
8
*
enc
[
2
*
i
+
0
]
+
lane
]);
half2
*
a1
=
reinterpret_cast
<
half2
*>
(
&
sh_code1
[
8
*
enc
[
2
*
i
+
1
]
+
lane
]);
#pragma unroll
half2
*
a0
=
reinterpret_cast
<
half2
*>
(
&
sh_code0
[
8
*
enc
[
2
*
i
+
0
]
+
lane
]);
half2
*
a1
=
reinterpret_cast
<
half2
*>
(
&
sh_code1
[
8
*
enc
[
2
*
i
+
1
]
+
lane
]);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
;
j
++
)
reinterpret_cast
<
half2
*>
(
&
chunk
)[
j
]
=
__hadd2
(
a0
[
j
],
a1
[
j
]);
C
[
a_gl_rd
*
8
+
i
]
=
chunk
;
...
...
@@ -317,22 +296,15 @@ __global__ void Code2x8Dequant(
}
}
inline
int
ceildiv
(
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
inline
int
ceildiv
(
int
a
,
int
b
)
{
return
(
a
+
b
-
1
)
/
b
;
}
const
int
THREAD_M
=
16
;
void
code1x16_matvec_cuda
(
const
void
*
__restrict__
A
,
const
void
*
__restrict__
B
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
const
int
codebook_stride
)
{
void
code1x16_matvec_cuda
(
const
void
*
__restrict__
A
,
const
void
*
__restrict__
B
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
const
int
codebook_stride
)
{
int
sms
;
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
0
);
int
waves
=
0
;
...
...
@@ -345,28 +317,16 @@ void code1x16_matvec_cuda(
int
blocks
=
ceildiv
(
prob_m
,
thread_m
);
int
threads
=
32
*
thread_m
;
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
Code1x16MatVec
<<<
blocks
,
threads
,
16
*
32
*
9
,
stream
>>>
(
(
const
int4
*
)
A
,
(
const
int4
*
)
B
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
);
Code1x16MatVec
<<<
blocks
,
threads
,
16
*
32
*
9
,
stream
>>>
(
(
const
int4
*
)
A
,
(
const
int4
*
)
B
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
);
}
void
code2x8_matvec_cuda
(
const
void
*
__restrict__
A
,
const
void
*
__restrict__
B
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
const
int
codebook_stride
)
{
void
code2x8_matvec_cuda
(
const
void
*
__restrict__
A
,
const
void
*
__restrict__
B
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
const
int
codebook_stride
)
{
int
sms
;
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
0
);
int
waves
=
0
;
...
...
@@ -379,30 +339,20 @@ void code2x8_matvec_cuda(
int
blocks
=
ceildiv
(
prob_m
,
thread_m
);
int
threads
=
32
*
thread_m
;
int
shared
=
16
*
(
2
*
256
*
8
+
32
*
9
);
cudaFuncSetAttribute
(
Code2x8MatVec
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shared
);
cudaFuncSetAttribute
(
Code2x8MatVec
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shared
);
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
Code2x8MatVec
<<<
blocks
,
threads
,
shared
,
stream
>>>
(
(
const
int4
*
)
A
,
(
const
int4
*
)
B
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
);
(
const
int4
*
)
A
,
(
const
int4
*
)
B
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
);
}
void
code1x16_dequant_cuda
(
const
void
*
__restrict__
A
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at most 3 long.
const
int
codebook_stride
// as int4.
const
void
*
__restrict__
A
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each
// codebook, at most 3 long.
const
int
codebook_stride
// as int4.
)
{
int
sms
;
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
0
);
...
...
@@ -417,25 +367,21 @@ void code1x16_dequant_cuda(
int
threads
=
32
*
thread_m
;
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
Code1x16Dequant
<<<
blocks
,
threads
,
0
,
stream
>>>
(
(
const
int4
*
)
A
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at most 3 long.
codebook_stride
// as int4.
(
const
int4
*
)
A
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at
// most 3 long.
codebook_stride
// as int4.
);
}
// Dequantizes the code and codebook into weights.
void
code2x8_dequant_cuda
(
const
void
*
__restrict__
A
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at most 3 long, corresponds to cols.
const
int
codebook_stride
// as int4
void
code2x8_dequant_cuda
(
const
void
*
__restrict__
A
,
void
*
__restrict__
C
,
const
void
*
__restrict__
codebook
,
int
prob_m
,
int
prob_k
,
const
int4
codebook_a_sizes
,
// cumulative sizes of A spanning each codebook, at
// most 3 long, corresponds to cols.
const
int
codebook_stride
// as int4
)
{
int
sms
;
cudaDeviceGetAttribute
(
&
sms
,
cudaDevAttrMultiProcessorCount
,
0
);
...
...
@@ -451,74 +397,50 @@ void code2x8_dequant_cuda(
int
shared
=
16
*
(
2
*
256
*
8
+
32
*
9
);
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
cudaFuncSetAttribute
(
Code2x8Dequant
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shared
);
cudaFuncSetAttribute
(
Code2x8Dequant
,
cudaFuncAttributeMaxDynamicSharedMemorySize
,
shared
);
Code2x8Dequant
<<<
blocks
,
threads
,
shared
,
stream
>>>
(
(
const
int4
*
)
A
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
);
(
const
int4
*
)
A
,
(
int4
*
)
C
,
(
const
int4
*
)
codebook
,
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
);
}
int
codebook_stride
(
const
torch
::
Tensor
&
codebooks
)
{
int
codebook_stride
(
const
torch
::
Tensor
&
codebooks
)
{
return
codebooks
.
stride
(
0
)
*
codebooks
.
element_size
()
/
sizeof
(
int4
);
}
void
code1x16_matvec
(
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
torch
::
Tensor
&
C
,
const
torch
::
Tensor
&
codebook
,
const
int4
codebook_a_sizes
// cumulative sizes of A spanning each codebook, at most 3 long.
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
torch
::
Tensor
&
C
,
const
torch
::
Tensor
&
codebook
,
const
int4
codebook_a_sizes
// cumulative sizes of A spanning each
// codebook, at most 3 long.
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
A
));
int
prob_m
=
C
.
size
(
0
);
int
prob_k
=
B
.
size
(
0
);
code1x16_matvec_cuda
(
A
.
data_ptr
(),
B
.
data_ptr
(),
C
.
data_ptr
(),
codebook
.
data_ptr
(),
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
(
codebook
)
);
code1x16_matvec_cuda
(
A
.
data_ptr
(),
B
.
data_ptr
(),
C
.
data_ptr
(),
codebook
.
data_ptr
(),
prob_m
,
prob_k
,
codebook_a_sizes
,
codebook_stride
(
codebook
));
}
torch
::
Tensor
code1x16_matmat
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
int4
codebook_a_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
)
{
torch
::
Tensor
code1x16_matmat
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
int4
codebook_a_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
)
{
auto
input_sizes
=
input
.
sizes
();
auto
out_features
=
codes
.
size
(
0
)
*
codebooks
.
size
(
2
);
auto
flat_input
=
input
.
reshape
({
-
1
,
input
.
size
(
-
1
)});
auto
flat_output
=
torch
::
empty
({
flat_input
.
size
(
0
),
out_features
},
torch
::
TensorOptions
()
.
dtype
(
input
.
dtype
())
.
device
(
input
.
device
())
);
auto
flat_output
=
torch
::
empty
(
{
flat_input
.
size
(
0
),
out_features
},
torch
::
TensorOptions
().
dtype
(
input
.
dtype
()).
device
(
input
.
device
()));
for
(
int
i
=
0
;
i
<
flat_input
.
size
(
0
);
++
i
)
{
auto
input_vec
=
flat_input
.
index
({
i
});
auto
output_vec
=
flat_output
.
index
({
i
});
code1x16_matvec
(
codes
.
squeeze
(
2
),
input_vec
,
output_vec
,
codebooks
,
codebook_a_sizes
);
code1x16_matvec
(
codes
.
squeeze
(
2
),
input_vec
,
output_vec
,
codebooks
,
codebook_a_sizes
);
}
flat_output
*=
scales
.
flatten
().
unsqueeze
(
0
);
...
...
@@ -533,55 +455,35 @@ torch::Tensor code1x16_matmat(
return
output
;
}
void
code2x8_matvec
(
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
torch
::
Tensor
&
C
,
const
torch
::
Tensor
&
codebook
,
const
int4
codebook_a_sizes
)
{
void
code2x8_matvec
(
const
torch
::
Tensor
&
A
,
const
torch
::
Tensor
&
B
,
torch
::
Tensor
&
C
,
const
torch
::
Tensor
&
codebook
,
const
int4
codebook_a_sizes
)
{
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
A
));
int
prob_m
=
C
.
size
(
0
);
int
prob_k
=
B
.
size
(
0
);
code2x8_matvec_cuda
(
A
.
data_ptr
(),
B
.
data_ptr
(),
C
.
data_ptr
(),
codebook
.
data_ptr
(),
prob_m
,
prob_k
,
codebook_a_sizes
,
2
*
codebook_stride
(
codebook
)
);
code2x8_matvec_cuda
(
A
.
data_ptr
(),
B
.
data_ptr
(),
C
.
data_ptr
(),
codebook
.
data_ptr
(),
prob_m
,
prob_k
,
codebook_a_sizes
,
2
*
codebook_stride
(
codebook
));
}
torch
::
Tensor
code2x8_matmat
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
int4
codebook_a_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
)
{
torch
::
Tensor
code2x8_matmat
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
int4
codebook_a_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
)
{
auto
input_sizes
=
input
.
sizes
();
auto
out_features
=
codes
.
size
(
0
)
*
codebooks
.
size
(
2
);
auto
flat_input
=
input
.
reshape
({
-
1
,
input
.
size
(
-
1
)});
auto
flat_output
=
torch
::
empty
({
flat_input
.
size
(
0
),
out_features
},
torch
::
TensorOptions
()
.
dtype
(
input
.
dtype
())
.
device
(
input
.
device
())
);
auto
flat_output
=
torch
::
empty
(
{
flat_input
.
size
(
0
),
out_features
},
torch
::
TensorOptions
().
dtype
(
input
.
dtype
()).
device
(
input
.
device
()));
for
(
int
i
=
0
;
i
<
flat_input
.
size
(
0
);
++
i
)
{
auto
input_vec
=
flat_input
.
index
({
i
});
auto
output_vec
=
flat_output
.
index
({
i
});
code2x8_matvec
(
codes
.
squeeze
(
2
),
input_vec
,
output_vec
,
codebooks
,
codebook_a_sizes
);
code2x8_matvec
(
codes
.
squeeze
(
2
),
input_vec
,
output_vec
,
codebooks
,
codebook_a_sizes
);
}
flat_output
*=
scales
.
flatten
().
unsqueeze
(
0
);
if
(
bias
.
has_value
())
{
...
...
@@ -596,64 +498,56 @@ torch::Tensor code2x8_matmat(
}
// Accumulate the partition sizes.
int4
accumulate_sizes
(
const
torch
::
Tensor
&
codebook_partition_sizes
)
{
int4
accumulate_sizes
(
const
torch
::
Tensor
&
codebook_partition_sizes
)
{
int4
cumulative_sizes
;
auto
cumulative_size
=
&
cumulative_sizes
.
x
;
int
i
=
0
;
int
last
=
0
;
assert
(
codebook_partition_sizes
.
size
(
0
)
<=
4
);
for
(;
i
<
codebook_partition_sizes
.
size
(
0
);
++
i
,
++
cumulative_size
)
{
for
(;
i
<
codebook_partition_sizes
.
size
(
0
);
++
i
,
++
cumulative_size
)
{
*
cumulative_size
=
codebook_partition_sizes
[
i
].
item
<
int
>
()
+
last
;
last
=
*
cumulative_size
;
}
// fill in the rest with unreachable.
for
(;
i
<
4
;
++
i
,
++
cumulative_size
)
{
*
cumulative_size
=
last
*
10
;
for
(;
i
<
4
;
++
i
,
++
cumulative_size
)
{
*
cumulative_size
=
last
*
10
;
}
return
cumulative_sizes
;
}
}
// namespace aqlm
}
// namespace vllm
}
// namespace aqlm
}
// namespace vllm
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
torch
::
Tensor
&
codebook_partition_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
)
{
int4
cumulative_sizes
=
vllm
::
aqlm
::
accumulate_sizes
(
codebook_partition_sizes
);
torch
::
Tensor
aqlm_gemm
(
const
torch
::
Tensor
&
input
,
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
scales
,
const
torch
::
Tensor
&
codebook_partition_sizes
,
const
std
::
optional
<
torch
::
Tensor
>&
bias
)
{
int4
cumulative_sizes
=
vllm
::
aqlm
::
accumulate_sizes
(
codebook_partition_sizes
);
int
const
nbooks
=
codebooks
.
size
(
0
)
/
codebook_partition_sizes
.
size
(
0
);
int
const
entries
=
codebooks
.
size
(
1
);
if
(
nbooks
==
1
&&
entries
==
(
1
<<
16
))
{
return
vllm
::
aqlm
::
code1x16_matmat
(
input
,
codes
,
codebooks
,
scales
,
cumulative_sizes
,
bias
);
if
(
nbooks
==
1
&&
entries
==
(
1
<<
16
))
{
return
vllm
::
aqlm
::
code1x16_matmat
(
input
,
codes
,
codebooks
,
scales
,
cumulative_sizes
,
bias
);
}
if
(
nbooks
==
2
&&
entries
==
(
1
<<
8
))
{
return
vllm
::
aqlm
::
code2x8_matmat
(
input
,
codes
,
codebooks
,
scales
,
cumulative_sizes
,
bias
);
if
(
nbooks
==
2
&&
entries
==
(
1
<<
8
))
{
return
vllm
::
aqlm
::
code2x8_matmat
(
input
,
codes
,
codebooks
,
scales
,
cumulative_sizes
,
bias
);
}
TORCH_CHECK
(
false
,
"AQLM with "
,
nbooks
,
" codebooks and "
,
entries
,
" entries is not currently supported."
)
TORCH_CHECK
(
false
,
"AQLM with "
,
nbooks
,
" codebooks and "
,
entries
,
" entries is not currently supported."
)
return
{};
}
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codebook_partition_sizes
)
{
int4
cumulative_sizes
=
vllm
::
aqlm
::
accumulate_sizes
(
codebook_partition_sizes
);
torch
::
Tensor
aqlm_dequant
(
const
torch
::
Tensor
&
codes
,
const
torch
::
Tensor
&
codebooks
,
const
torch
::
Tensor
&
codebook_partition_sizes
)
{
int4
cumulative_sizes
=
vllm
::
aqlm
::
accumulate_sizes
(
codebook_partition_sizes
);
int
const
nbooks
=
codebooks
.
size
(
0
)
/
codebook_partition_sizes
.
size
(
0
);
int
const
entries
=
codebooks
.
size
(
1
);
...
...
@@ -668,45 +562,37 @@ torch::Tensor aqlm_dequant(
assert
(
out_features
=
codebook_partition_sizes
.
sum
().
item
<
int
>
());
auto
weights
=
torch
::
empty
({
out_features
,
in_features
},
torch
::
TensorOptions
()
.
dtype
(
codebooks
.
dtype
())
.
device
(
codebooks
.
device
())
);
torch
::
TensorOptions
()
.
dtype
(
codebooks
.
dtype
())
.
device
(
codebooks
.
device
()));
if
(
nbooks
==
1
&&
entries
==
(
1
<<
16
))
{
vllm
::
aqlm
::
code1x16_dequant_cuda
(
codes
.
data_ptr
(),
weights
.
data_ptr
(),
codebooks
.
data_ptr
(),
out_features
,
in_features
,
cumulative_sizes
,
vllm
::
aqlm
::
codebook_stride
(
codebooks
));
if
(
nbooks
==
1
&&
entries
==
(
1
<<
16
))
{
vllm
::
aqlm
::
code1x16_dequant_cuda
(
codes
.
data_ptr
(),
weights
.
data_ptr
(),
codebooks
.
data_ptr
(),
out_features
,
in_features
,
cumulative_sizes
,
vllm
::
aqlm
::
codebook_stride
(
codebooks
));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation.)
// weights *= scales.index({"...", 0, 0});
return
weights
;
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower
// and not consistent with gemv implementation.) weights *=
// scales.index({"...", 0, 0});
return
weights
;
}
if
(
nbooks
==
2
&&
entries
==
(
1
<<
8
))
{
vllm
::
aqlm
::
code2x8_dequant_cuda
(
codes
.
data_ptr
(),
weights
.
data_ptr
(),
codebooks
.
data_ptr
(),
out_features
,
in_features
,
cumulative_sizes
,
vllm
::
aqlm
::
codebook_stride
(
codebooks
));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower and not consistent with gemv implementation)
// weights *= scales.index({"...", 0, 0});
return
weights
;
if
(
nbooks
==
2
&&
entries
==
(
1
<<
8
))
{
vllm
::
aqlm
::
code2x8_dequant_cuda
(
codes
.
data_ptr
(),
weights
.
data_ptr
(),
codebooks
.
data_ptr
(),
out_features
,
in_features
,
cumulative_sizes
,
vllm
::
aqlm
::
codebook_stride
(
codebooks
));
// if you wanted to flip to scaling the weights, (though it's 30%-ish slower
// and not consistent with gemv implementation) weights *=
// scales.index({"...", 0, 0});
return
weights
;
}
TORCH_CHECK
(
false
,
"AQLM with "
,
nbooks
,
" codebooks and "
,
entries
,
" entries is not currently supported."
)
TORCH_CHECK
(
false
,
"AQLM with "
,
nbooks
,
" codebooks and "
,
entries
,
" entries is not currently supported."
)
return
{};
}
csrc/quantization/awq/dequantize.cuh
View file @
5f6d10c1
/*
Adapted from https://github.com/mit-han-lab/llm-awq
Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
Modified from NVIDIA FasterTransformer:
https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/cutlass_extensions/include/cutlass_extensions/interleaved_numeric_conversion.h
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
title={AWQ: Activation-aware Weight Quantization for LLM Compression and
Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
}
*/
...
...
@@ -14,74 +14,88 @@ Modified from NVIDIA FasterTransformer: https://github.com/NVIDIA/FasterTransfor
namespace
vllm
{
namespace
awq
{
__device__
uint4
dequantize_s4_to_fp16x2
(
uint32_t
const
&
source
)
{
__device__
uint4
dequantize_s4_to_fp16x2
(
uint32_t
const
&
source
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
assert
(
false
);
#else
uint4
result
;
uint4
result
;
uint32_t
*
h
=
reinterpret_cast
<
uint32_t
*>
(
&
result
);
uint32_t
const
i4s
=
reinterpret_cast
<
uint32_t
const
&>
(
source
);
uint32_t
*
h
=
reinterpret_cast
<
uint32_t
*>
(
&
result
);
uint32_t
const
i4s
=
reinterpret_cast
<
uint32_t
const
&>
(
source
);
// First, we extract the i4s and construct an intermediate fp16 number.
static
constexpr
uint32_t
immLut
=
(
0xf0
&
0xcc
)
|
0xaa
;
static
constexpr
uint32_t
BOTTOM_MASK
=
0x000f000f
;
static
constexpr
uint32_t
TOP_MASK
=
0x00f000f0
;
static
constexpr
uint32_t
I4s_TO_F16s_MAGIC_NUM
=
0x64006400
;
// First, we extract the i4s and construct an intermediate fp16 number.
static
constexpr
uint32_t
immLut
=
(
0xf0
&
0xcc
)
|
0xaa
;
static
constexpr
uint32_t
BOTTOM_MASK
=
0x000f000f
;
static
constexpr
uint32_t
TOP_MASK
=
0x00f000f0
;
static
constexpr
uint32_t
I4s_TO_F16s_MAGIC_NUM
=
0x64006400
;
// Note that the entire sequence only requires 1 shift instruction. This is thanks to the register packing
// format and the fact that we force our integers to be unsigned, and account for this in the fp16 subtractions.
// In addition, I exploit the fact that sub and fma have the same throughput in order to convert elt_23 and
// elt_67 to fp16 without having to shift them to the bottom bits before hand.
// Note that the entire sequence only requires 1 shift instruction. This is
// thanks to the register packing format and the fact that we force our
// integers to be unsigned, and account for this in the fp16 subtractions. In
// addition, I exploit the fact that sub and fma have the same throughput in
// order to convert elt_23 and elt_67 to fp16 without having to shift them to
// the bottom bits before hand.
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW dependency if we issue
// immediately before required.
const
uint32_t
top_i4s
=
i4s
>>
8
;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
0
])
:
"r"
(
i4s
),
"n"
(
BOTTOM_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
1
])
:
"r"
(
i4s
),
"n"
(
TOP_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
2
])
:
"r"
(
top_i4s
),
"n"
(
BOTTOM_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
3
])
:
"r"
(
top_i4s
),
"n"
(
TOP_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// Shift right by 8 to now consider elt_45 and elt_67. Issue first to hide RAW
// dependency if we issue immediately before required.
const
uint32_t
top_i4s
=
i4s
>>
8
;
// Extract elt_01 - (i4s & 0x000f000f) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
0
])
:
"r"
(
i4s
),
"n"
(
BOTTOM_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_23 (i4s & 0x00f000f0) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
1
])
:
"r"
(
i4s
),
"n"
(
TOP_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_45 (top_i4s & 0x000f000f) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
2
])
:
"r"
(
top_i4s
),
"n"
(
BOTTOM_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// Extract elt_67 (top_i4s & 0x00f000f0) | 0x64006400
asm
volatile
(
"lop3.b32 %0, %1, %2, %3, %4;
\n
"
:
"=r"
(
h
[
3
])
:
"r"
(
top_i4s
),
"n"
(
TOP_MASK
),
"n"
(
I4s_TO_F16s_MAGIC_NUM
),
"n"
(
immLut
));
// I use inline PTX below because I am not sure if the compiler will emit float2half instructions if I use the
// half2 ctor. In this case, I chose performance reliability over code readability.
// I use inline PTX below because I am not sure if the compiler will emit
// float2half instructions if I use the half2 ctor. In this case, I chose
// performance reliability over code readability.
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static
constexpr
uint32_t
FP16_TOP_MAGIC_NUM
=
0x64006400
;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static
constexpr
uint32_t
ONE_SIXTEENTH
=
0x2c002c00
;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static
constexpr
uint32_t
NEG_64
=
0xd400d400
;
// This is the half2 {1032, 1032} represented as an integer.
// static constexpr uint32_t FP16_TOP_MAGIC_NUM = 0x64086408;
// Haotian: subtract {1024, 1024} instead, we do not need to map to [-8, 7]
static
constexpr
uint32_t
FP16_TOP_MAGIC_NUM
=
0x64006400
;
// This is the half2 {1 / 16, 1 / 16} represented as an integer.
static
constexpr
uint32_t
ONE_SIXTEENTH
=
0x2c002c00
;
// This is the half2 {-72, -72} represented as an integer.
// static constexpr uint32_t NEG_72 = 0xd480d480;
// Haotian: Let's use {-64, -64}.
static
constexpr
uint32_t
NEG_64
=
0xd400d400
;
// Finally, we construct the output numbers.
// Convert elt_01
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
h
[
0
])
:
"r"
(
h
[
0
]),
"r"
(
FP16_TOP_MAGIC_NUM
));
// Convert elt_23
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
h
[
1
])
:
"r"
(
h
[
1
]),
"r"
(
ONE_SIXTEENTH
),
"r"
(
NEG_64
));
// Convert elt_45
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
h
[
2
])
:
"r"
(
h
[
2
]),
"r"
(
FP16_TOP_MAGIC_NUM
));
// Convert elt_67
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
h
[
3
])
:
"r"
(
h
[
3
]),
"r"
(
ONE_SIXTEENTH
),
"r"
(
NEG_64
));
// Finally, we construct the output numbers.
// Convert elt_01
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
h
[
0
])
:
"r"
(
h
[
0
]),
"r"
(
FP16_TOP_MAGIC_NUM
));
// Convert elt_23
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
h
[
1
])
:
"r"
(
h
[
1
]),
"r"
(
ONE_SIXTEENTH
),
"r"
(
NEG_64
));
// Convert elt_45
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
h
[
2
])
:
"r"
(
h
[
2
]),
"r"
(
FP16_TOP_MAGIC_NUM
));
// Convert elt_67
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
h
[
3
])
:
"r"
(
h
[
3
]),
"r"
(
ONE_SIXTEENTH
),
"r"
(
NEG_64
));
return
result
;
return
result
;
#endif
}
}
// namespace awq
}
// namespace vllm
}
// namespace awq
}
// namespace vllm
csrc/quantization/awq/gemm_kernels.cu
View file @
5f6d10c1
/*
Adapted from https://github.com/mit-han-lab/llm-awq
@article{lin2023awq,
title={AWQ: Activation-aware Weight Quantization for LLM Compression and Acceleration},
author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang, Shang and Dang, Xingyu and Han, Song},
journal={arXiv},
year={2023}
title={AWQ: Activation-aware Weight Quantization for LLM Compression and
Acceleration}, author={Lin, Ji and Tang, Jiaming and Tang, Haotian and Yang,
Shang and Dang, Xingyu and Han, Song}, journal={arXiv}, year={2023}
}
*/
#include <torch/extension.h>
#include <c10/cuda/CUDAGuard.h>
...
...
@@ -20,26 +18,20 @@ namespace vllm {
namespace
awq
{
// Pack two half values.
static
inline
__device__
__host__
unsigned
__pack_half2
(
const
half
x
,
const
half
y
)
{
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
static
inline
__device__
__host__
unsigned
__pack_half2
(
const
half
x
,
const
half
y
)
{
unsigned
v0
=
*
((
unsigned
short
*
)
&
x
);
unsigned
v1
=
*
((
unsigned
short
*
)
&
y
);
return
(
v1
<<
16
)
|
v0
;
}
template
<
int
N
>
__global__
void
__launch_bounds__
(
64
)
gemm_forward_4bit_cuda_m16nXk32
(
int
G
,
int
split_k_iters
,
half
*
__restrict__
A
,
int
*
__restrict__
B
,
half
*
__restrict__
scaling_factors
,
int
*
__restrict__
zeros
,
int
M
,
int
IC
,
int
OC
,
half
*
__restrict__
C
)
{
template
<
int
N
>
__global__
void
__launch_bounds__
(
64
)
gemm_forward_4bit_cuda_m16nXk32
(
int
G
,
int
split_k_iters
,
half
*
__restrict__
A
,
int
*
__restrict__
B
,
half
*
__restrict__
scaling_factors
,
int
*
__restrict__
zeros
,
int
M
,
int
IC
,
int
OC
,
half
*
__restrict__
C
)
{
// Only support matrix n = 64 or 128
assert
(
N
==
64
||
N
==
128
);
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ < 750
...
...
@@ -70,43 +62,46 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
static
constexpr
int
row_stride
=
2
*
32
*
8
/
N
;
bool
ld_zero_flag
=
(
threadIdx
.
y
*
32
+
threadIdx
.
x
)
*
8
<
N
;
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
bool
ld_A_flag
=
(
blockIdx_y
/
j_factors1
*
16
+
threadIdx
.
y
*
row_stride_warp
+
threadIdx
.
x
*
8
/
32
)
<
M
;
// threadIdx.y is warp_id
bool
ld_A_flag
=
(
blockIdx_y
/
j_factors1
*
16
+
threadIdx
.
y
*
row_stride_warp
+
threadIdx
.
x
*
8
/
32
)
<
M
;
// threadIdx.y is warp_id
// bool wb_C_flag = (threadIdx.x / 4) < M;
half
*
A_ptr
=
A
+
(((
int
)
blockIdx_y
)
/
j_factors1
*
16
+
(((
int
)
threadIdx
.
y
)
*
row_stride_warp
)
+
((
int
)
threadIdx
.
x
)
/
(
32
/
8
))
*
IC
+
(((
int
)
threadIdx
.
x
)
%
(
32
/
8
))
*
8
;
int
*
B_ptr
=
B
+
((
int
)
threadIdx
.
y
)
*
(
OC
/
8
)
*
(
256
/
N
)
+
(((
int
)
threadIdx
.
x
)
/
(
N
/
8
))
*
(
OC
/
8
)
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
(
N
/
8
)
+
(((
int
)
threadIdx
.
x
)
%
(
N
/
8
))
*
1
;
// Why * 1 in the above line?
half
*
A_shared_ptr
=
A_shared
+
((
int
)
threadIdx
.
y
)
*
row_stride_warp
*
(
32
+
8
)
+
(((
int
)
threadIdx
.
x
)
/
(
32
/
8
))
*
(
32
+
8
)
+
(((
int
)
threadIdx
.
x
)
%
(
32
/
8
)
)
*
8
;
half
*
B_shared_ptr
=
B_shared
+
((
int
)
threadIdx
.
y
)
*
(
row_stride
/
2
)
*
(
N
+
8
)
+
(((
int
)
threadIdx
.
x
)
/
(
N
/
8
))
*
(
N
+
8
)
+
(((
int
)
threadIdx
.
x
)
%
(
N
/
8
))
*
8
;
int
*
zeros_ptr
=
zeros
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
(
N
/
8
)
+
((
int
)
threadIdx
.
x
)
%
(
N
/
8
);
half
*
scaling_factors_ptr
=
scaling_factors
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
N
+
(((
int
)
threadIdx
.
x
)
%
(
N
/
8
))
*
8
;
half
*
C_ptr
=
C
+
static_cast
<
long
long
>
(
blockIdx_z
)
*
M
*
OC
// blockIdz.x -> split_k dim
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
N
+
((
int
)
threadIdx
.
y
)
*
(
N
/
2
)
+
(((
int
)
threadIdx
.
x
)
%
4
)
*
2
;
half
*
A_ptr
=
A
+
(((
int
)
blockIdx_y
)
/
j_factors1
*
16
+
(((
int
)
threadIdx
.
y
)
*
row_stride_warp
)
+
((
int
)
threadIdx
.
x
)
/
(
32
/
8
))
*
IC
+
(((
int
)
threadIdx
.
x
)
%
(
32
/
8
))
*
8
;
int
*
B_ptr
=
B
+
((
int
)
threadIdx
.
y
)
*
(
OC
/
8
)
*
(
256
/
N
)
+
(((
int
)
threadIdx
.
x
)
/
(
N
/
8
))
*
(
OC
/
8
)
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
(
N
/
8
)
+
(((
int
)
threadIdx
.
x
)
%
(
N
/
8
))
*
1
;
// Why * 1 in the above line?
half
*
A_shared_ptr
=
A_shared
+
((
int
)
threadIdx
.
y
)
*
row_stride_warp
*
(
32
+
8
)
+
(((
int
)
threadIdx
.
x
)
/
(
32
/
8
))
*
(
32
+
8
)
+
(((
int
)
threadIdx
.
x
)
%
(
32
/
8
))
*
8
;
half
*
B_shared_ptr
=
B_shared
+
((
int
)
threadIdx
.
y
)
*
(
row_stride
/
2
)
*
(
N
+
8
)
+
(((
int
)
threadIdx
.
x
)
/
(
N
/
8
))
*
(
N
+
8
)
+
(((
int
)
threadIdx
.
x
)
%
(
N
/
8
))
*
8
;
int
*
zeros_ptr
=
zeros
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
(
N
/
8
)
+
((
int
)
threadIdx
.
x
)
%
(
N
/
8
);
half
*
scaling_factors_ptr
=
scaling_factors
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
N
+
(((
int
)
threadIdx
.
x
)
%
(
N
/
8
))
*
8
;
half
*
C_ptr
=
C
+
static_cast
<
long
long
>
(
blockIdx_z
)
*
M
*
OC
// blockIdz.x -> split_k dim
+
(((
int
)
blockIdx_y
)
%
j_factors1
)
*
N
+
((
int
)
threadIdx
.
y
)
*
(
N
/
2
)
+
(((
int
)
threadIdx
.
x
)
%
4
)
*
2
;
// preload s.f. and zeros
int
k_bound
=
(
IC
/
32
+
split_k_iters
-
1
)
/
split_k_iters
;
...
...
@@ -115,57 +110,83 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
int
k_0_0
=
_k_0_0
*
split_k_iters
+
blockIdx_z
;
__syncthreads
();
// TODO: Haotian: blockIdx_y / j_factors1 in A loading to support bsz > 16
if
(
ld_A_flag
)
{
if
(
ld_A_flag
)
{
*
(
uint4
*
)(
A_shared_ptr
)
=
*
(
uint4
*
)(
A_ptr
+
(
k_0_0
*
32
));
}
else
{
}
else
{
*
(
uint4
*
)(
A_shared_ptr
)
=
make_uint4
(
0
,
0
,
0
,
0
);
}
// for (int ax0_ax1_fused_0 = 0; ax0_ax1_fused_0 < 2; ++ax0_ax1_fused_0) {
uint32_t
zeros_loaded
=
*
(
uint32_t
*
)(
zeros_ptr
+
k_0_0
*
32
/
G
*
(
OC
/
8
));
uint4
B_loaded_zero
=
dequantize_s4_to_fp16x2
(
zeros_loaded
);
uint4
B_loaded_scale
=
*
(
uint4
*
)(
scaling_factors_ptr
+
k_0_0
*
32
/
G
*
(
OC
));
uint4
B_loaded_scale
=
*
(
uint4
*
)(
scaling_factors_ptr
+
k_0_0
*
32
/
G
*
(
OC
));
/*
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 && threadIdx.y == 0){
printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x, B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x, B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
if (blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 0 &&
threadIdx.y == 0){ printf("%x %x %x %x %x %x %x %x\n", B_loaded_scale.x,
B_loaded_scale.y, B_loaded_scale.z, B_loaded_scale.w, B_loaded_zero.x,
B_loaded_zero.y, B_loaded_zero.z, B_loaded_zero.w);
}
*/
// uint4 B_loaded_scale = make_uint4(0, 0, 0, 0);
int
*
B_ptr_local
=
B_ptr
+
k_0_0
*
32
*
(
OC
/
8
);
for
(
int
ax0_ax1_fused_0
=
0
;
ax0_ax1_fused_0
<
N
/
16
;
++
ax0_ax1_fused_0
)
{
// B: 32 x 136 (128+8) float16
// each warp: 32 x 4
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) * 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15) * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 * 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) * 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) * 8)));
// row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t
B_loaded
=
*
(
uint32_t
*
)(
B_ptr_local
+
ax0_ax1_fused_0
*
row_stride
*
(
OC
/
8
));
// each thr: read 32 bit -> convert to 8xFP16 (a UINT4) -> scale and minus
// zero -> WB UINT4
// *(uint4*)(B_shared + ((((ax0_ax1_fused_0 * 544) + (((int)threadIdx.y) *
// 272)) + ((((int)threadIdx.x) >> 4) * 136)) + ((((int)threadIdx.x) & 15)
// * 8))) = *(uint4*)(B + ((((((k_0_0 * 163840) + (ax0_ax1_fused_0 *
// 20480)) + (((int)threadIdx.y) * 10240)) + ((((int)threadIdx.x) >> 4) *
// 5120)) + (((int)blockIdx_y) * 128)) + ((((int)threadIdx.x) & 15) *
// 8))); row stride in shared memory: (NWARPS * 32 * 8 / cta_N)
uint32_t
B_loaded
=
*
(
uint32_t
*
)(
B_ptr_local
+
ax0_ax1_fused_0
*
row_stride
*
(
OC
/
8
));
uint4
B_loaded_fp16
=
dequantize_s4_to_fp16x2
(
B_loaded
);
//uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_zero = *(uint4*)(zeros_shared + (threadIdx.x % (cta_N /
// 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x % (cta_N / 8)) * 8);
// uint4 B_loaded_scale = *(uint4*)(scaling_factors_shared + (threadIdx.x
// % (cta_N / 8)) * 8);
// - zero and * scale
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq = q * scale - zero * scale.
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_zero
.
x
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_scale
.
x
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_zero
.
y
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_scale
.
y
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_zero
.
z
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_scale
.
z
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_zero
.
w
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_scale
.
w
),
"r"
(
ZERO
));
// TODO (Haotian): can save 4 assembly instructions if sormulate as deq =
// q * scale - zero * scale.
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_zero
.
x
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_scale
.
x
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_zero
.
y
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_scale
.
y
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_zero
.
z
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_scale
.
z
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_zero
.
w
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_scale
.
w
),
"r"
(
ZERO
));
/*
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 == 0 && threadIdx.x == 17 && threadIdx.y == 0){
printf("[x] %X %X %X %X\n", B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
if (ax0_ax1_fused_0 == 0 && blockIdx_z == 0 && blockIdx_y == 0 && k_0_0 ==
0 && threadIdx.x == 17 && threadIdx.y == 0){ printf("[x] %X %X %X %X\n",
B_loaded_fp16.x, B_loaded_fp16.y, B_loaded_fp16.z, B_loaded_fp16.w);
}
*/
// write back
*
(
uint4
*
)(
B_shared_ptr
+
ax0_ax1_fused_0
*
row_stride
*
(
N
+
8
))
=
B_loaded_fp16
;
*
(
uint4
*
)(
B_shared_ptr
+
ax0_ax1_fused_0
*
row_stride
*
(
N
+
8
))
=
B_loaded_fp16
;
}
__syncthreads
();
...
...
@@ -173,112 +194,179 @@ __global__ void __launch_bounds__(64) gemm_forward_4bit_cuda_m16nXk32(
{
unsigned
int
addr
;
__asm__
__volatile__
(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }
\n
"
:
"=r"
(
addr
)
:
"l"
((
void
*
)((
&
(
A_shared
[(
k_0_1
*
16
)]))
+
(((((
int
)
threadIdx
.
x
)
&
15
)
*
40
)
+
((((
int
)
threadIdx
.
x
)
>>
4
)
*
8
))))
);
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
"addr; }
\n
"
:
"=r"
(
addr
)
:
"l"
((
void
*
)((
&
(
A_shared
[(
k_0_1
*
16
)]))
+
(((((
int
)
threadIdx
.
x
)
&
15
)
*
40
)
+
((((
int
)
threadIdx
.
x
)
>>
4
)
*
8
)))));
__asm__
__volatile__
(
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];
\n
"
:
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
])
:
"r"
(
addr
)
);
"ldmatrix.sync.aligned.m8n8.x4.shared.b16"
"{%0, %1, %2, %3}, [%4];
\n
"
:
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"=r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
])
:
"r"
(
addr
));
}
for
(
int
ax1_0
=
0
;
ax1_0
<
N
/
32
;
++
ax1_0
)
{
{
unsigned
int
addr
;
__asm__
__volatile__
(
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, addr; }
\n
"
:
"=r"
(
addr
)
:
"l"
((
void
*
)((
&
(
B_shared
[(((
k_0_1
*
(
N
*
16
+
128
))
+
(((
int
)
threadIdx
.
y
)
*
(
N
/
2
)))
+
(
ax1_0
*
16
))]))
+
(((((
int
)
threadIdx
.
x
)
&
15
)
*
(
N
+
8
))
+
((((
int
)
threadIdx
.
x
)
>>
4
)
*
8
))))
);
"{ .reg .u64 addr; cvta.to.shared.u64 addr, %1; cvt.u32.u64 %0, "
"addr; }
\n
"
:
"=r"
(
addr
)
:
"l"
((
void
*
)((
&
(
B_shared
[(((
k_0_1
*
(
N
*
16
+
128
))
+
(((
int
)
threadIdx
.
y
)
*
(
N
/
2
)))
+
(
ax1_0
*
16
))]))
+
(((((
int
)
threadIdx
.
x
)
&
15
)
*
(
N
+
8
))
+
((((
int
)
threadIdx
.
x
)
>>
4
)
*
8
)))));
__asm__
__volatile__
(
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];
\n
"
:
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
0
]),
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
1
]),
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
2
]),
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
3
])
:
"r"
(
addr
)
);
"ldmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
"{%0, %1, %2, %3}, [%4];
\n
"
:
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
0
]),
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
1
]),
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
2
]),
"=r"
(((
unsigned
*
)(
B_shared_warp
+
(
ax1_0
*
8
)))[
3
])
:
"r"
(
addr
));
}
}
for
(
int
j_0_4
=
0
;
j_0_4
<
N
/
32
;
++
j_0_4
)
{
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 750
{
__asm__
__volatile__
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
]));
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
]));
}
{
__asm__
__volatile__
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
]));
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
]));
}
{
__asm__
__volatile__
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
]));
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
]));
}
{
__asm__
__volatile__
(
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
]));
"mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
]));
}
#else
#else
{
__asm__
__volatile__
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
0
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
]));
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
"%13};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
0
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
(
j_0_4
*
8
)))[
3
]));
}
{
__asm__
__volatile__
(
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
]));
"mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32"
"{%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, "
"%13};
\n
"
:
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"=f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
])
:
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
0
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
1
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
2
]),
"r"
(((
unsigned
*
)(
A_shared_warp
+
0
))[
3
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"r"
(((
unsigned
*
)(
B_shared_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
0
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
1
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
2
]),
"f"
(((
float
*
)(
C_warp
+
((
j_0_4
*
8
)
+
4
)))[
3
]));
}
#endif
#endif
}
}
}
// TODO: Shang: Hoist loop invariance.
// TODO: Shang: Hoist loop invariance.
for
(
int
ax1_0_1
=
0
;
ax1_0_1
<
4
;
++
ax1_0_1
)
{
for
(
int
local_id
=
0
;
local_id
<
8
;
++
local_id
)
{
int
row_offset
=
(((
int
)
blockIdx_y
)
/
j_factors1
)
*
16
+
((
int
)
threadIdx
.
x
)
/
4
+
(
local_id
%
4
)
/
2
*
8
;
if
(
row_offset
<
M
)
{
*
(
C_ptr
+
ax1_0_1
*
16
+
row_offset
*
OC
+
(
local_id
/
4
)
*
8
+
local_id
%
2
)
=
__float2half
(
C_warp
[(
ax1_0_1
*
8
)
+
local_id
]);
int
row_offset
=
(((
int
)
blockIdx_y
)
/
j_factors1
)
*
16
+
((
int
)
threadIdx
.
x
)
/
4
+
(
local_id
%
4
)
/
2
*
8
;
if
(
row_offset
<
M
)
{
*
(
C_ptr
+
ax1_0_1
*
16
+
row_offset
*
OC
+
(
local_id
/
4
)
*
8
+
local_id
%
2
)
=
__float2half
(
C_warp
[(
ax1_0_1
*
8
)
+
local_id
]);
}
}
}
#endif
}
__global__
void
__launch_bounds__
(
64
)
dequantize_weights
(
int
*
__restrict__
B
,
half
*
__restrict__
scaling_factors
,
int
*
__restrict__
zeros
,
half
*
__restrict__
C
,
int
G
)
{
__global__
void
__launch_bounds__
(
64
)
dequantize_weights
(
int
*
__restrict__
B
,
half
*
__restrict__
scaling_factors
,
int
*
__restrict__
zeros
,
half
*
__restrict__
C
,
int
G
)
{
int
j_factors1
=
4
;
int
row_stride2
=
4
;
int
split_k_iters
=
1
;
...
...
@@ -310,14 +398,30 @@ __global__ void __launch_bounds__(64) dequantize_weights(
uint32_t
B_loaded
=
*
(
uint32_t
*
)
B_ptr2
;
uint4
B_loaded_fp16
=
dequantize_s4_to_fp16x2
(
B_loaded
);
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_zero
.
x
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_scale
.
x
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_zero
.
y
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_scale
.
y
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_zero
.
z
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_scale
.
z
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_zero
.
w
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_scale
.
w
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_zero
.
x
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
x
)
:
"r"
(
B_loaded_fp16
.
x
),
"r"
(
B_loaded_scale
.
x
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_zero
.
y
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
y
)
:
"r"
(
B_loaded_fp16
.
y
),
"r"
(
B_loaded_scale
.
y
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_zero
.
z
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
z
)
:
"r"
(
B_loaded_fp16
.
z
),
"r"
(
B_loaded_scale
.
z
),
"r"
(
ZERO
));
asm
volatile
(
"sub.f16x2 %0, %1, %2;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_zero
.
w
));
asm
volatile
(
"fma.rn.f16x2 %0, %1, %2, %3;
\n
"
:
"=r"
(
B_loaded_fp16
.
w
)
:
"r"
(
B_loaded_fp16
.
w
),
"r"
(
B_loaded_scale
.
w
),
"r"
(
ZERO
));
*
(
uint4
*
)
B_shared_ptr2
=
B_loaded_fp16
;
...
...
@@ -326,58 +430,57 @@ __global__ void __launch_bounds__(64) dequantize_weights(
}
}
}
// namespace awq
}
// namespace vllm
torch
::
Tensor
awq_dequantize
(
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
,
int
thx
,
int
thy
)
{
int
in_c
=
_kernel
.
size
(
0
);
int
qout_c
=
_kernel
.
size
(
1
);
int
out_c
=
qout_c
*
8
;
int
G
=
in_c
/
_scaling_factors
.
size
(
0
);
int
x_thread
=
thx
;
int
y_thread
=
thy
;
int
x_blocks
=
1
;
int
y_blocks
=
1
;
if
(
thx
==
0
)
{
x_thread
=
qout_c
;
}
if
(
thy
==
0
)
{
y_thread
=
in_c
;
}
if
(
thx
==
0
&&
thy
==
0
)
{
x_thread
=
8
;
y_thread
=
8
;
x_blocks
=
(
int
)(
qout_c
/
8
);
y_blocks
=
(
int
)(
in_c
/
8
);
}
}
// namespace awq
}
// namespace vllm
torch
::
Tensor
awq_dequantize
(
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scaling_factors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
,
int
thx
,
int
thy
)
{
int
in_c
=
_kernel
.
size
(
0
);
int
qout_c
=
_kernel
.
size
(
1
);
int
out_c
=
qout_c
*
8
;
int
G
=
in_c
/
_scaling_factors
.
size
(
0
);
int
x_thread
=
thx
;
int
y_thread
=
thy
;
int
x_blocks
=
1
;
int
y_blocks
=
1
;
if
(
thx
==
0
)
{
x_thread
=
qout_c
;
}
if
(
thy
==
0
)
{
y_thread
=
in_c
;
}
if
(
thx
==
0
&&
thy
==
0
)
{
x_thread
=
8
;
y_thread
=
8
;
x_blocks
=
(
int
)(
qout_c
/
8
);
y_blocks
=
(
int
)(
in_c
/
8
);
}
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
_scaling_factors
));
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
_scaling_factors
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
_scaling_factors
.
dtype
()).
device
(
_scaling_factors
.
device
());
at
::
Tensor
_de_kernel
=
torch
::
empty
({
in_c
,
out_c
},
options
);
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
_scaling_factors
.
dtype
())
.
device
(
_scaling_factors
.
device
());
at
::
Tensor
_de_kernel
=
torch
::
empty
({
in_c
,
out_c
},
options
);
auto
kernel
=
reinterpret_cast
<
int
*>
(
_kernel
.
data_ptr
<
int
>
());
auto
de_kernel
=
reinterpret_cast
<
half
*>
(
_de_kernel
.
data_ptr
<
at
::
Half
>
());
auto
scaling_factors
=
reinterpret_cast
<
half
*>
(
_scaling_factors
.
data_ptr
<
at
::
Half
>
());
auto
zeros
=
reinterpret_cast
<
int
*>
(
_zeros
.
data_ptr
<
int
>
());
auto
kernel
=
reinterpret_cast
<
int
*>
(
_kernel
.
data_ptr
<
int
>
());
auto
de_kernel
=
reinterpret_cast
<
half
*>
(
_de_kernel
.
data_ptr
<
at
::
Half
>
());
auto
scaling_factors
=
reinterpret_cast
<
half
*>
(
_scaling_factors
.
data_ptr
<
at
::
Half
>
());
auto
zeros
=
reinterpret_cast
<
int
*>
(
_zeros
.
data_ptr
<
int
>
());
dim3
num_blocks
(
x_blocks
,
y_blocks
);
dim3
threads_per_block
(
x_thread
,
y_thread
);
dim3
num_blocks
(
x_blocks
,
y_blocks
);
dim3
threads_per_block
(
x_thread
,
y_thread
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
vllm
::
awq
::
dequantize_weights
<<<
num_blocks
,
threads_per_block
,
0
,
stream
>>>
(
kernel
,
scaling_factors
,
zeros
,
de_kernel
,
G
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
vllm
::
awq
::
dequantize_weights
<<<
num_blocks
,
threads_per_block
,
0
,
stream
>>>
(
kernel
,
scaling_factors
,
zeros
,
de_kernel
,
G
);
return
_de_kernel
;
return
_de_kernel
;
}
// in_feats: M, IC [float16]
...
...
@@ -386,61 +489,61 @@ torch::Tensor awq_dequantize(
// zeros: IC // G, OC // 8 [int32] -> cast to IC // G, OC [uint4b]
// assume that batch_size < 16 for now
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_in_f
eat
s
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_scal
in
g
_f
actors
,
torch
::
Tensor
_zeros
,
int
split_k_iters
)
{
int
num_in_feats
=
_in_feats
.
size
(
0
);
int
num_in_channels
=
_in_feats
.
size
(
1
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device
_of
(
_in_feats
));
auto
options
=
torch
::
TensorOptions
().
dtype
(
_in_feats
.
dtype
()).
device
(
_in_feats
.
device
()
);
at
::
Tensor
_out_feats
=
torch
::
empty
({
split_k_iters
,
num_in_feats
,
_kernel
.
size
(
1
)
*
8
},
options
);
int
num_out_
feat
s
=
_out_feats
.
size
(
-
2
);
int
num_out_channels
=
_out_feats
.
size
(
-
1
);
auto
in_feats
=
reinterpret_cast
<
half
*>
(
_in_feats
.
data_ptr
<
at
::
Half
>
());
auto
kernel
=
reinterpret_cast
<
int
*>
(
_kernel
.
data_ptr
<
int
>
());
auto
out_feats
=
reinterpret_cast
<
half
*>
(
_out_feats
.
data_ptr
<
at
::
Half
>
());
auto
scaling_factors
=
reinterpret_cast
<
half
*>
(
_scaling_factors
.
data_ptr
<
at
::
Half
>
());
auto
zeros
=
reinterpret_cast
<
int
*>
(
_zeros
.
data_ptr
<
int
>
());
int
group_size
=
num_in_channels
/
_scaling_factors
.
size
(
0
);
if
(
num_out_channels
%
64
!=
0
)
throw
std
::
invalid_argument
(
"OC is not multiple of cta_N = 64"
);
if
(
num_out_channels
%
8
!=
0
)
throw
std
::
invalid_argument
(
"OC is not multiple of pack_num = 8"
);
if
(
group_size
%
32
!=
0
)
throw
std
::
invalid_argument
(
"Group size should be a multiple of 32"
);
if
(
num_out_channels
%
group_size
!=
0
)
throw
std
::
invalid_argument
(
"OC is not multiple of Group size"
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
num_out_channels
%
128
==
0
)
{
int
j_factors1
=
num_out_
channel
s
/
1
28
/
1
;
dim3
num_blocks
((
num_out_feats
+
16
-
1
)
/
16
*
j_factors1
*
split_k_iters
);
// threadIdx.
x
:
32
//
thread
Idx.y: i_factors[2] * j_factors[2]
dim3
threads_per_block
(
32
,
2
);
vllm
::
awq
::
gemm_forward_4bit_cuda_m16nXk32
<
128
>
<<<
num_blocks
,
threads_per_block
,
0
,
stream
>>>
(
group_size
,
split_k_iters
,
in_feats
,
kernel
,
scaling_factors
,
zeros
,
num_in_feats
,
num_in_channels
,
num_out_channels
,
out_feats
);
}
else
if
(
num_out_channels
%
64
==
0
)
{
int
j_factors1
=
num_out_channels
/
64
/
1
;
dim3
num_blocks
(
1
*
(
num_out_feats
+
16
-
1
)
/
16
*
j_factors1
*
split_k_iters
);
// threadIdx.
x
:
32
//
thread
Idx.y: i_factors[2] * j_factors[2]
dim3
threads_per_block
(
32
,
2
);
vllm
::
awq
::
gemm_forward_4bit_cuda_m16nXk32
<
64
>
<<<
num_blocks
,
threads_per_block
,
0
,
stream
>>>
(
group_size
,
split_k_iters
,
in_feats
,
kernel
,
scaling_factors
,
zeros
,
num_in_feats
,
num_in_channels
,
num_out_channels
,
out_feats
);
}
return
_out_feats
.
sum
(
0
);
torch
::
Tensor
awq_gemm
(
torch
::
Tensor
_in_feats
,
torch
::
Tensor
_kernel
,
torch
::
Tensor
_
scal
in
g
_f
actors
,
torch
::
Tensor
_zero
s
,
int
split_k_iters
)
{
int
num_in_feats
=
_
in_f
eats
.
size
(
0
);
int
num_in_channels
=
_in_feats
.
size
(
1
);
const
at
::
cuda
::
OptionalCUDAGuard
device_guard
(
device_of
(
_in_feats
));
auto
options
=
torch
::
TensorOptions
()
.
dtype
(
_in_feats
.
dtype
())
.
device
(
_in_feats
.
device
(
));
at
::
Tensor
_out_feats
=
torch
::
empty
({
split_k_iters
,
num_in_feats
,
_kernel
.
size
(
1
)
*
8
},
options
);
int
num_out_feats
=
_out_feats
.
size
(
-
2
);
int
num_out_
channel
s
=
_out_feats
.
size
(
-
1
);
auto
in_feats
=
reinterpret_cast
<
half
*>
(
_in_feats
.
data_ptr
<
at
::
Half
>
());
auto
kernel
=
reinterpret_cast
<
int
*>
(
_kernel
.
data_ptr
<
int
>
());
auto
out_feats
=
reinterpret_cast
<
half
*>
(
_out_feats
.
data_ptr
<
at
::
Half
>
());
auto
scaling_factors
=
reinterpret_cast
<
half
*>
(
_scaling_factors
.
data_ptr
<
at
::
Half
>
());
auto
zeros
=
reinterpret_cast
<
int
*>
(
_zeros
.
data_ptr
<
int
>
());
int
group_size
=
num_in_channels
/
_scaling_factors
.
size
(
0
);
if
(
num_out_channels
%
64
!=
0
)
throw
std
::
invalid_argument
(
"OC is not multiple of cta_N = 64"
);
if
(
num_out_channels
%
8
!=
0
)
throw
std
::
invalid_argument
(
"OC is not multiple of pack_num = 8"
);
if
(
group_size
%
32
!=
0
)
throw
std
::
invalid_argument
(
"Group size should be a multiple of 32"
);
if
(
num_out_channels
%
group_size
!=
0
)
throw
std
::
invalid_argument
(
"OC is not multiple of Group size"
);
const
cudaStream_t
stream
=
at
::
cuda
::
getCurrentCUDAStream
();
if
(
num_out_channels
%
128
==
0
)
{
int
j_factors1
=
num_out_channels
/
128
/
1
;
dim3
num_blocks
((
num_out_
feat
s
+
1
6
-
1
)
/
16
*
j_factors1
*
split_k_iters
)
;
// threadIdx.x: 32
// threadIdx.
y
:
i_factors[2] * j_factors[2]
dim3
thread
s_per_block
(
32
,
2
);
vllm
::
awq
::
gemm_forward_4bit_cuda_m16nXk32
<
128
>
<<<
num_blocks
,
threads_per_block
,
0
,
stream
>>>
(
group_size
,
split_k_iters
,
in_feats
,
kernel
,
scaling_factors
,
zeros
,
num_in_feats
,
num_in_channels
,
num_out_channels
,
out_feats
);
}
else
if
(
num_out_channels
%
64
==
0
)
{
int
j_factors1
=
num_out_channels
/
64
/
1
;
dim3
num_blocks
(
1
*
(
num_out_feats
+
16
-
1
)
/
16
*
j_factors1
*
split_k_iters
)
;
// threadIdx.x: 32
// threadIdx.
y
:
i_factors[2] * j_factors[2]
dim3
thread
s_per_block
(
32
,
2
);
vllm
::
awq
::
gemm_forward_4bit_cuda_m16nXk32
<
64
>
<<<
num_blocks
,
threads_per_block
,
0
,
stream
>>>
(
group_size
,
split_k_iters
,
in_feats
,
kernel
,
scaling_factors
,
zeros
,
num_in_feats
,
num_in_channels
,
num_out_channels
,
out_feats
);
}
return
_out_feats
.
sum
(
0
);
}
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c2x.cu
View file @
5f6d10c1
...
...
@@ -117,10 +117,10 @@ struct cutlass_2x_gemm {
};
template
<
typename
Gemm
>
void
cutlass_scaled_mm_dq_dispatcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_scaled_mm_dq_dispatcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
...
...
@@ -136,9 +136,9 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
using
StrideC
=
Stride
<
int64_t
,
Int
<
1
>
,
Int
<
0
>>
;
StrideC
c_stride
{
ldc
,
Int
<
1
>
{},
Int
<
0
>
{}};
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
const
*>
(
b
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
auto
a_ptr
=
static_cast
<
ElementAB
const
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
const
*>
(
b
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
auto
a_scales_ptr
=
a_scales
.
data_ptr
<
float
>
();
auto
b_scales_ptr
=
b_scales
.
data_ptr
<
float
>
();
...
...
@@ -196,10 +196,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
}
// namespace
void
cutlass_scaled_mm_dq_sm75
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_scaled_mm_dq_sm75
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
...
...
@@ -223,10 +223,10 @@ void cutlass_scaled_mm_dq_sm75(torch::Tensor &out, torch::Tensor const &a,
}
}
void
cutlass_scaled_mm_dq_sm80
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_scaled_mm_dq_sm80
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
b
.
dtype
()
==
torch
::
kInt8
);
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
...
...
@@ -250,10 +250,10 @@ void cutlass_scaled_mm_dq_sm80(torch::Tensor &out, torch::Tensor const &a,
}
}
void
cutlass_scaled_mm_dq_sm89
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_scaled_mm_dq_sm89
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
TileShape
=
typename
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
;
using
WarpShape
=
typename
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
;
using
InstructionShape
=
typename
cutlass
::
gemm
::
GemmShape
<
16
,
8
,
32
>
;
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_dq_c3x.cu
View file @
5f6d10c1
...
...
@@ -120,10 +120,10 @@ struct cutlass_3x_gemm {
};
template
<
typename
Gemm
>
void
cutlass_scaled_mm_dq_dispatcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_scaled_mm_dq_dispatcher
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
using
ElementAB
=
typename
Gemm
::
ElementAB
;
using
ElementD
=
typename
Gemm
::
ElementD
;
...
...
@@ -146,12 +146,12 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
using
GemmKernel
=
typename
Gemm
::
GemmKernel
;
typename
GemmKernel
::
ProblemShape
prob_shape
{
m
,
n
,
k
,
1
};
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
auto
a_ptr
=
static_cast
<
ElementAB
*>
(
a
.
data_ptr
());
auto
b_ptr
=
static_cast
<
ElementAB
*>
(
b
.
data_ptr
());
typename
GemmKernel
::
MainloopArguments
mainloop_args
{
a_ptr
,
a_stride
,
b_ptr
,
b_stride
};
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
auto
c_ptr
=
static_cast
<
ElementD
*>
(
out
.
data_ptr
());
typename
GemmKernel
::
EpilogueArguments
epilogue_args
{
{},
c_ptr
,
c_stride
,
c_ptr
,
c_stride
};
...
...
@@ -183,10 +183,10 @@ void cutlass_scaled_mm_dq_dispatcher(torch::Tensor &out, torch::Tensor const &a,
}
}
// namespace
void
cutlass_scaled_mm_dq_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_scaled_mm_dq_sm90
(
torch
::
Tensor
&
out
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
TORCH_CHECK
(
a_scales
.
dtype
()
==
torch
::
kFloat32
);
TORCH_CHECK
(
b_scales
.
dtype
()
==
torch
::
kFloat32
);
...
...
csrc/quantization/cutlass_w8a8/scaled_mm_dq_entry.cu
View file @
5f6d10c1
...
...
@@ -2,29 +2,29 @@
#include <cuda_runtime.h>
#include <torch/extension.h>
void
cutlass_scaled_mm_dq_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm75
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm80
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm89
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq_sm90
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
);
void
cutlass_scaled_mm_dq
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
void
cutlass_scaled_mm_dq
(
torch
::
Tensor
&
c
,
torch
::
Tensor
const
&
a
,
torch
::
Tensor
const
&
b
,
torch
::
Tensor
const
&
a_scales
,
torch
::
Tensor
const
&
b_scales
)
{
int32_t
major_capability
;
int32_t
minor_capability
;
cudaDeviceGetAttribute
(
&
major_capability
,
cudaDevAttrComputeCapabilityMajor
,
...
...
@@ -36,14 +36,15 @@ void cutlass_scaled_mm_dq(torch::Tensor &c, torch::Tensor const &a,
// Checks for conformality
TORCH_CHECK
(
a
.
dim
()
==
2
&&
b
.
dim
()
==
2
&&
c
.
dim
()
==
2
);
TORCH_CHECK
(
c
.
size
(
0
)
==
a
.
size
(
0
)
&&
a
.
size
(
1
)
==
b
.
size
(
0
)
&&
b
.
size
(
1
)
==
c
.
size
(
1
));
b
.
size
(
1
)
==
c
.
size
(
1
));
TORCH_CHECK
(
a_scales
.
numel
()
==
1
||
a_scales
.
numel
()
==
a
.
size
(
0
));
TORCH_CHECK
(
b_scales
.
numel
()
==
1
||
b_scales
.
numel
()
==
b
.
size
(
1
));
// Check for strides and alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
c
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
b
.
stride
(
0
)
==
1
);
// Column-major
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
&&
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
a
.
stride
(
1
)
==
1
&&
c
.
stride
(
1
)
==
1
);
// Row-major
TORCH_CHECK
(
b
.
stride
(
0
)
==
1
);
// Column-major
TORCH_CHECK
(
c
.
stride
(
0
)
%
16
==
0
&&
b
.
stride
(
1
)
%
16
==
0
);
// 16 Byte Alignment
TORCH_CHECK
(
a_scales
.
is_contiguous
()
&&
b_scales
.
is_contiguous
());
at
::
cuda
::
OptionalCUDAGuard
const
device_guard
(
device_of
(
a
));
...
...
csrc/quantization/fp8/amd/hip_float8.h
View file @
5f6d10c1
#pragma once
#ifdef __HIPCC__
#include <hip/hip_runtime.h>
#include <hip/hip_runtime.h>
#else
#include <type_traits>
#include <stdint.h>
#include <math.h>
#include <iostream>
#include <type_traits>
#include <stdint.h>
#include <math.h>
#include <iostream>
#endif
#include "hip_float8_impl.h"
struct
alignas
(
1
)
hip_fp8
{
struct
from_bits_t
{
};
HIP_FP8_HOST_DEVICE
static
constexpr
from_bits_t
from_bits
()
{
return
from_bits_t
();
}
uint8_t
data
;
hip_fp8
()
=
default
;
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
(
const
hip_fp8
&
)
=
default
;
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
(
uint8_t
v
)
=
delete
;
explicit
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
(
uint8_t
v
,
from_bits_t
)
:
data
(
v
)
{
}
struct
alignas
(
1
)
hip_fp8
{
struct
from_bits_t
{};
HIP_FP8_HOST_DEVICE
static
constexpr
from_bits_t
from_bits
()
{
return
from_bits_t
();
}
uint8_t
data
;
hip_fp8
()
=
default
;
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
(
const
hip_fp8
&
)
=
default
;
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
(
uint8_t
v
)
=
delete
;
explicit
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
(
uint8_t
v
,
from_bits_t
)
:
data
(
v
)
{}
#ifdef __HIP__MI300__
// NOTE: ON-DEVICE... always optimal bias
explicit
HIP_FP8_DEVICE
hip_fp8
(
float
v
)
:
data
(
hip_fp8_impl
::
to_fp8_from_fp32
(
v
))
{
}
explicit
HIP_FP8_DEVICE
hip_fp8
(
_Float16
v
)
:
hip_fp8
(
static_cast
<
float
>
(
v
))
{
}
// Host only implementation using s/w simulation
explicit
HIP_FP8_HOST
#else // __HIP__MI300__
// both Host and DEVICE for non-MI300 using s/w simulation
explicit
HIP_FP8_HOST_DEVICE
#endif // __HIP__MI300__
hip_fp8
(
float
v
)
{
data
=
hip_fp8_impl
::
to_float8
<
4
,
3
,
float
,
true
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
);
}
explicit
HIP_FP8_HOST_DEVICE
hip_fp8
(
double
v
)
:
hip_fp8
(
static_cast
<
float
>
(
v
))
{
}
// NOTE: ON-DEVICE... always optimal bias
explicit
HIP_FP8_DEVICE
hip_fp8
(
float
v
)
:
data
(
hip_fp8_impl
::
to_fp8_from_fp32
(
v
))
{}
explicit
HIP_FP8_DEVICE
hip_fp8
(
_Float16
v
)
:
hip_fp8
(
static_cast
<
float
>
(
v
))
{}
// Host only implementation using s/w simulation
explicit
HIP_FP8_HOST
#else // __HIP__MI300__
// both Host and DEVICE for non-MI300 using s/w simulation
explicit
HIP_FP8_HOST_DEVICE
#endif // __HIP__MI300__
hip_fp8
(
float
v
)
{
data
=
hip_fp8_impl
::
to_float8
<
4
,
3
,
float
,
true
/*negative_zero_nan*/
,
true
/*clip*/
>
(
v
);
}
explicit
HIP_FP8_HOST_DEVICE
hip_fp8
(
double
v
)
:
hip_fp8
(
static_cast
<
float
>
(
v
))
{}
#ifdef __HIP__MI300__
// upcast using device specific intrinsic
explicit
inline
HIP_FP8_DEVICE
operator
float
()
const
{
float
fval
;
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
data
);
// upcast
asm
volatile
(
"v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
:
"=v"
(
fval
)
:
"v"
(
i32val
));
return
fval
;
}
explicit
inline
HIP_FP8_HOST
operator
float
()
const
#else // __HIP__MI300__
explicit
inline
HIP_FP8_HOST_DEVICE
operator
float
()
const
#endif // __HIP__MI300__
{
return
hip_fp8_impl
::
from_float8
<
4
,
3
,
float
,
true
/*negative_zero_nan*/
>
(
data
);
}
// upcast using device specific intrinsic
explicit
inline
HIP_FP8_DEVICE
operator
float
()
const
{
float
fval
;
uint32_t
i32val
=
static_cast
<
uint32_t
>
(
data
);
// upcast
asm
volatile
(
"v_cvt_f32_fp8 %0, %1 src0_sel:BYTE_0"
:
"=v"
(
fval
)
:
"v"
(
i32val
));
return
fval
;
}
explicit
inline
HIP_FP8_HOST
operator
float
()
const
#else // __HIP__MI300__
explicit
inline
HIP_FP8_HOST_DEVICE
operator
float
()
const
#endif // __HIP__MI300__
{
return
hip_fp8_impl
::
from_float8
<
4
,
3
,
float
,
true
/*negative_zero_nan*/
>
(
data
);
}
};
namespace
std
{
inline
hip_fp8
sin
(
hip_fp8
a
)
{
return
hip_fp8
(
sinf
(
float
(
a
)));
}
inline
hip_fp8
cos
(
hip_fp8
a
)
{
return
hip_fp8
(
cosf
(
float
(
a
)));
}
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
real
(
const
hip_fp8
&
a
)
{
return
a
;
}
}
// namespace std
namespace
std
{
inline
hip_fp8
sin
(
hip_fp8
a
)
{
return
hip_fp8
(
sinf
(
float
(
a
)));
}
inline
hip_fp8
cos
(
hip_fp8
a
)
{
return
hip_fp8
(
cosf
(
float
(
a
)));
}
HIP_FP8_HOST_DEVICE
constexpr
hip_fp8
real
(
const
hip_fp8
&
a
)
{
return
a
;
}
}
// namespace std
// Special operator overloading
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
hip_fp8
&
f8
)
{
return
os
<<
float
(
f8
);
inline
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
hip_fp8
&
f8
)
{
return
os
<<
float
(
f8
);
}
// all + operator overloading with mixed types
// mixed types, always converts to f32, does computation in f32, and returns
float
inline
HIP_FP8_HOST_DEVICE
float
operator
+
(
const
float
fa
,
hip_fp8
b
)
{
return
(
fa
+
float
(
b
));
// mixed types, always converts to f32, does computation in f32, and returns
// float
inline
HIP_FP8_HOST_DEVICE
float
operator
+
(
const
float
fa
,
hip_fp8
b
)
{
return
(
fa
+
float
(
b
));
}
inline
HIP_FP8_HOST_DEVICE
float
operator
+
(
hip_fp8
a
,
const
float
fb
)
{
return
(
float
(
a
)
+
fb
);
inline
HIP_FP8_HOST_DEVICE
float
operator
+
(
hip_fp8
a
,
const
float
fb
)
{
return
(
float
(
a
)
+
fb
);
}
inline
HIP_FP8_HOST_DEVICE
hip_fp8
operator
+
(
hip_fp8
a
,
hip_fp8
b
)
{
return
hip_fp8
(
float
(
a
)
+
float
(
b
));
inline
HIP_FP8_HOST_DEVICE
hip_fp8
operator
+
(
hip_fp8
a
,
hip_fp8
b
)
{
return
hip_fp8
(
float
(
a
)
+
float
(
b
));
}
inline
HIP_FP8_HOST_DEVICE
hip_fp8
&
operator
+=
(
hip_fp8
&
a
,
hip_fp8
b
)
{
return
a
=
hip_fp8
(
float
(
a
)
+
float
(
b
));
inline
HIP_FP8_HOST_DEVICE
hip_fp8
&
operator
+=
(
hip_fp8
&
a
,
hip_fp8
b
)
{
return
a
=
hip_fp8
(
float
(
a
)
+
float
(
b
));
}
// overloading multiplication, always returns float,
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
hip_fp8
a
,
hip_fp8
b
)
{
return
float
(
a
)
*
float
(
b
);
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
hip_fp8
a
,
hip_fp8
b
)
{
return
float
(
a
)
*
float
(
b
);
}
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
float
a
,
hip_fp8
b
)
{
return
(
a
*
float
(
b
));
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
float
a
,
hip_fp8
b
)
{
return
(
a
*
float
(
b
));
}
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
hip_fp8
a
,
float
b
)
{
return
(
float
(
a
)
*
b
);
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
hip_fp8
a
,
float
b
)
{
return
(
float
(
a
)
*
b
);
}
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
int32_t
a
,
hip_fp8
b
)
{
return
((
float
)
a
*
float
(
b
));
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
int32_t
a
,
hip_fp8
b
)
{
return
((
float
)
a
*
float
(
b
));
}
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
double
a
,
hip_fp8
b
)
{
return
((
float
)
a
*
float
(
b
));
inline
HIP_FP8_HOST_DEVICE
float
operator
*
(
double
a
,
hip_fp8
b
)
{
return
((
float
)
a
*
float
(
b
));
}
// overloading for compare
inline
HIP_FP8_HOST_DEVICE
bool
operator
==
(
hip_fp8
a
,
hip_fp8
b
)
{
return
(
a
.
data
==
b
.
data
);
inline
HIP_FP8_HOST_DEVICE
bool
operator
==
(
hip_fp8
a
,
hip_fp8
b
)
{
return
(
a
.
data
==
b
.
data
);
}
inline
HIP_FP8_HOST_DEVICE
bool
operator
!=
(
hip_fp8
a
,
hip_fp8
b
)
{
return
(
a
.
data
!=
b
.
data
);
inline
HIP_FP8_HOST_DEVICE
bool
operator
!=
(
hip_fp8
a
,
hip_fp8
b
)
{
return
(
a
.
data
!=
b
.
data
);
}
inline
HIP_FP8_HOST_DEVICE
bool
operator
>=
(
hip_fp8
a
,
hip_fp8
b
)
{
return
static_cast
<
float
>
(
a
)
>=
static_cast
<
float
>
(
b
);
inline
HIP_FP8_HOST_DEVICE
bool
operator
>=
(
hip_fp8
a
,
hip_fp8
b
)
{
return
static_cast
<
float
>
(
a
)
>=
static_cast
<
float
>
(
b
);
}
inline
HIP_FP8_HOST_DEVICE
bool
operator
>
(
hip_fp8
a
,
hip_fp8
b
)
{
return
static_cast
<
float
>
(
a
)
>
static_cast
<
float
>
(
b
);
inline
HIP_FP8_HOST_DEVICE
bool
operator
>
(
hip_fp8
a
,
hip_fp8
b
)
{
return
static_cast
<
float
>
(
a
)
>
static_cast
<
float
>
(
b
);
}
csrc/quantization/fp8/amd/hip_float8_impl.h
View file @
5f6d10c1
#pragma once
#if defined(__HIPCC__) && (defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300__
#if defined(__HIPCC__) && \
(defined(__gfx940__) || defined(__gfx941__) || defined(__gfx942__))
#define __HIP__MI300__
#endif
#ifdef __HIPCC__
#define HIP_FP8_HOST_DEVICE __host__ __device__
#define HIP_FP8_HOST __host__
#define HIP_FP8_DEVICE __device__
#define HIP_FP8_HOST_DEVICE __host__ __device__
#define HIP_FP8_HOST __host__
#define HIP_FP8_DEVICE __device__
#else
#define HIP_FP8_HOST_DEVICE
#define HIP_FP8_HOST
#define HIP_FP8_DEVICE
#define HIP_FP8_HOST_DEVICE
#define HIP_FP8_HOST
#define HIP_FP8_DEVICE
#endif
namespace
hip_fp8_impl
{
namespace
hip_fp8_impl
{
#ifdef __HIP__MI300__
HIP_FP8_DEVICE
uint8_t
to_fp8_from_fp32
(
float
v
)
{
u
int8_t
i8data
;
union
{
float
f
val
;
uint
32
_t
i
32
val
;
uint8_t
i8val
[
4
];
// NOTE: not endian independent
}
val
;
uint32_t
i
val
=
0
;
val
.
fval
=
v
;
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
240.0
,
-
240.0
);
}
ival
=
__builtin_amdgcn_cvt_pk_fp8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0
val
.
i32val
=
ival
;
i8data
=
val
.
i8val
[
0
];
return
i8data
;
HIP_FP8_DEVICE
uint8_t
to_fp8_from_fp32
(
float
v
)
{
uint8_t
i8data
;
u
nion
{
float
fval
;
uint32_t
i32
val
;
uint
8
_t
i
8
val
[
4
];
// NOTE: not endian independent
}
val
;
uint32_t
ival
=
0
;
val
.
f
val
=
v
;
if
((
val
.
i32val
&
0x7F800000
)
!=
0x7F800000
)
{
/// propagate NAN/INF, no clipping
val
.
fval
=
__builtin_amdgcn_fmed3f
(
val
.
fval
,
240.0
,
-
240.0
);
}
ival
=
__builtin_amdgcn_cvt_pk_fp8_f32
(
val
.
fval
,
val
.
fval
,
ival
,
false
);
// false -> WORD0
val
.
i32val
=
ival
;
i8data
=
val
.
i8val
[
0
];
return
i8data
;
}
#endif // __HIP__MI300__
#endif
// __HIP__MI300__
HIP_FP8_HOST
inline
int
clz
(
uint32_t
x
)
{
return
__builtin_clz
(
x
);
}
HIP_FP8_HOST
inline
int
clz
(
uint32_t
x
)
{
return
__builtin_clz
(
x
);
}
#if defined(__HIPCC__) || defined(__CUDA_ARCH__)
HIP_FP8_DEVICE
inline
int
clz
(
uint32_t
x
)
{
return
__clz
(
x
);
}
HIP_FP8_DEVICE
inline
int
clz
(
uint32_t
x
)
{
return
__clz
(
x
);
}
#endif
template
<
int
we
,
int
wm
,
typename
T
,
bool
negative_zero_nan
,
bool
clip
>
HIP_FP8_HOST_DEVICE
uint8_t
to_float8
(
T
_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
)
{
HIP_FP8_HOST_DEVICE
uint8_t
to_float8
(
T
_x
,
bool
stoch
=
false
,
uint32_t
rng
=
0
)
{
#ifdef __HIPCC__
constexpr
bool
is_half
=
std
::
is_same
<
T
,
_Float16
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
T
,
_Float16
>::
value
;
#else
constexpr
bool
is_half
=
false
;
constexpr
bool
is_half
=
false
;
#endif
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
is_half
||
is_float
,
"Only half and float can be cast to f8"
);
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
uint32_t
x
;
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
wm
+
we
==
7
,
"wm+we==7"
);
static_assert
(
is_half
||
is_float
,
"Only half and float can be cast to f8"
);
const
int
mfmt
=
(
sizeof
(
T
)
==
4
)
?
23
:
10
;
uint32_t
x
;
if
(
sizeof
(
T
)
==
4
)
{
x
=
reinterpret_cast
<
uint32_t
&>
(
_x
);
}
else
{
x
=
reinterpret_cast
<
uint16_t
&>
(
_x
);
}
uint32_t
head
,
mantissa
;
int
exponent
,
bias
;
uint32_t
sign
;
if
(
sizeof
(
T
)
==
4
)
{
head
=
x
&
0xFF800000
;
mantissa
=
x
&
0x7FFFFF
;
exponent
=
(
head
>>
23
)
&
0xFF
;
sign
=
head
>>
31
;
bias
=
127
;
}
else
{
head
=
x
&
0xFC00
;
mantissa
=
x
&
0x3FF
;
exponent
=
(
head
>>
10
)
&
0x1F
;
sign
=
head
>>
15
;
bias
=
15
;
}
uint32_t
signed_inf
=
(
sign
<<
7
)
+
(((
1
<<
we
)
-
1
)
<<
wm
);
// Deal with inf and NaNs
if
(
negative_zero_nan
)
{
if
(
sizeof
(
T
)
==
4
)
{
x
=
reinterpret_cast
<
uint32_t
&>
(
_x
);
if
((
x
&
0x7F800000
)
==
0x7F800000
)
{
return
0x80
;
}
}
else
{
x
=
reinterpret_cast
<
uint16_t
&>
(
_x
);
// if(__hisinf(x) || __hisnan(x))
if
((
x
&
0x7C00
)
==
0x7C00
)
{
return
0x80
;
}
}
uint32_t
head
,
mantissa
;
int
exponent
,
bias
;
uint32_t
sign
;
}
else
{
if
(
sizeof
(
T
)
==
4
)
{
head
=
x
&
0xFF800000
;
mantissa
=
x
&
0x7FFFFF
;
exponent
=
(
head
>>
23
)
&
0xFF
;
sign
=
head
>>
31
;
bias
=
127
;
if
((
x
&
0x7F800000
)
==
0x7F800000
)
{
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
}
else
{
head
=
x
&
0xFC00
;
mantissa
=
x
&
0x3FF
;
exponent
=
(
head
>>
10
)
&
0x1F
;
sign
=
head
>>
15
;
bias
=
15
;
if
((
x
&
0x7C00
)
==
0x7C00
)
{
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
}
uint32_t
signed_inf
=
(
sign
<<
7
)
+
(((
1
<<
we
)
-
1
)
<<
wm
);
// Deal with inf and NaNs
if
(
negative_zero_nan
)
{
if
(
sizeof
(
T
)
==
4
)
{
if
((
x
&
0x7F800000
)
==
0x7F800000
)
{
return
0x80
;
}
}
else
{
// if(__hisinf(x) || __hisnan(x))
if
((
x
&
0x7C00
)
==
0x7C00
)
{
return
0x80
;
}
}
}
else
{
if
(
sizeof
(
T
)
==
4
)
{
if
((
x
&
0x7F800000
)
==
0x7F800000
)
{
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
}
else
{
if
((
x
&
0x7C00
)
==
0x7C00
)
{
return
signed_inf
+
(
mantissa
!=
0
?
1
:
0
);
}
}
}
if
(
x
==
0
)
{
return
0
;
}
// First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
// need to check whether there is carry and adjust exponent and mantissa again
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const
int
f8_bias
=
(
1
<<
(
we
-
1
))
-
1
+
(
negative_zero_nan
?
1
:
0
);
const
int
f8_denormal_act_exponent
=
1
-
f8_bias
;
// actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int
act_exponent
,
f8_exponent
,
exponent_diff
;
if
(
exponent
==
0
)
{
// fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
}
if
(
x
==
0
)
{
return
0
;
}
// First need to check if it is normal or denorm as there is a difference of
// implicit 1 Then need to adjust the exponent to align with the F8 exponent,
// in the meanwhile, shift The mantissa. Then for stochastic rounding, add rng
// to mantissa and truncate. And for RNE, no need to add rng. Then probably
// need to check whether there is carry and adjust exponent and mantissa again
// For IEEE bias mode, the bias is 2^(k-1) -1 where k is the width of exponent
// bits
const
int
f8_bias
=
(
1
<<
(
we
-
1
))
-
1
+
(
negative_zero_nan
?
1
:
0
);
const
int
f8_denormal_act_exponent
=
1
-
f8_bias
;
// actual exponent of f8 denormal
// act_exponent is the actual exponent of fp32/fp16 (after subtracting bias)
// f8_exponent is the converted f8 exponent with bias encoding
// exponent_diff is the diff between fp32/fp16 exponent and f8 exponent,
// the difference needs to be adjusted and mantissa shifted
int
act_exponent
,
f8_exponent
,
exponent_diff
;
if
(
exponent
==
0
)
{
// fp32/fp16 is in denormal.
/* fp32 denormal is below 2^-127 so it is usually not a concern here, we
mostly concern fp16 here. In this case, f8 is usually in denormal. But there
could be exceptions. fp16 denormal has exponent bias 15 while bf8 with NANOO has
exponent bias 16. It means that there are some numbers in fp16 denormal but they
are bf8 (NANOO) normals - smallest bf8 (NANOO) normal is 2^-15. fp16 numbers
where exponent==0 (actual exponent -14) and highest bit of mantissa is 1 are bf8
(NANOO) normal. In this case, the fp16 mantissa should be shift left by 1 */
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
}
else
{
// fp32/fp16 is normal with implici
t
1
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
f8_denormal_act_exponent
)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
range. For example fp8 nanoo mode,
denormal
exponent
is -7, but if the
fp32/fp16 actual exponent is -7, it is actually larger due to the implicit 1,
Therefore it needs to be adjust to -6 and ma
ntis
sa shift right by 1.
So for
fp32/fp16
,
exponent
-8 is the cut point to convert to fp8 nanoo */
exponent_diff
=
f8_denormal_act_exponent
-
a
ct_exponent
;
}
else
{
// both fp32/fp16 and f8 are in normal range
exponent_diff
=
0
;
// exponent_diff=0 does not mean there is no difference
// for this case,
//
act_
exponent
could be larger. Just that it does not need shift mantissa
}
mantissa
+=
(
1
<<
mfmt
);
// Add the implicit 1 into
mantissa
act_exponent
=
exponent
-
bias
+
1
;
exponent_diff
=
f8_denormal_act_exponen
t
-
act_exponent
;
// actual exponent is exponent-bias+1 as it is denormal
}
else
{
// fp32/fp16 is normal with implicit 1
act_exponent
=
exponent
-
bias
;
if
(
act_exponent
<=
f8_
denormal
_act_
exponent
)
{
/* This is the case where fp32/fp16 is normal but it is in f8 denormal
range. For example fp8 nanoo mode, denormal expone
nt
is
-7, but if the
fp32/fp16
actual
exponent
is -7, it is actually larger due to the implicit 1,
Therefore it needs to be adjust to
-
6
a
nd mantissa shift right by 1.
So for fp32/fp16, exponent -8 is the cut point to convert to fp8 nanoo */
exponent_diff
=
f8_denormal_act_exponent
-
act_exponent
;
}
else
{
// both fp32/fp16 and f8 are in normal range
exponent_diff
=
0
;
// exponent
_diff=0 does not mean there is no
// difference for this case, act_exponent could be
// larger. Just that it does not need shift
mantissa
}
bool
midpoint
=
(
mantissa
&
((
1
<<
(
mfmt
-
wm
+
exponent_diff
))
-
1
))
==
static_cast
<
uint32_t
>
(
1
<<
(
mfmt
-
wm
+
exponent_diff
-
1
));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part
and make something not midpoint look like midpoint. For example, the fp16
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
shift right by 4 bits, it would look like midpoint.
mantissa
+=
(
1
<<
mfmt
);
// Add the implicit 1 into mantissa
}
bool
midpoint
=
(
mantissa
&
((
1
<<
(
mfmt
-
wm
+
exponent_diff
))
-
1
))
==
static_cast
<
uint32_t
>
(
1
<<
(
mfmt
-
wm
+
exponent_diff
-
1
));
/* This part is a bit tricky. The judgment of whether it is a tie needs to be
done before we shift right as shift right could rip off some residual part
and make something not midpoint look like midpoint. For example, the fp16
number 0x1002 (0 00100 0000000010), it is larger than midpoint, but after
shift right by 4 bits, it would look like midpoint.
*/
if
(
exponent_diff
>
0
)
{
mantissa
>>=
exponent_diff
;
}
else
if
(
exponent_diff
==
-
1
)
{
mantissa
<<=
-
exponent_diff
;
if
(
exponent_diff
>
0
)
{
mantissa
>>=
exponent_diff
;
}
else
if
(
exponent_diff
==
-
1
)
{
mantissa
<<=
-
exponent_diff
;
}
bool
implicit_one
=
mantissa
&
(
1
<<
mfmt
);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
f8_bias
-
(
implicit_one
?
0
:
1
);
// Now we have the exponent and mantissa adjusted
uint32_t
drop_mask
=
(
1
<<
(
mfmt
-
wm
))
-
1
;
bool
odd
=
mantissa
&
(
1
<<
(
mfmt
-
wm
));
// if the least significant bit
// that is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
// Now we deal with overflow
if
(
f8_exponent
==
0
)
{
if
((
1
<<
mfmt
)
&
mantissa
)
{
f8_exponent
=
1
;
// denormal overflow to become normal, promote exponent
}
bool
implicit_one
=
mantissa
&
(
1
<<
mfmt
);
// if there is no implicit 1, it means the f8 is denormal and need to adjust
// to denorm exponent
f8_exponent
=
(
act_exponent
+
exponent_diff
)
/*actual f8 exponent*/
+
f8_bias
-
(
implicit_one
?
0
:
1
);
// Now we have the exponent and mantissa adjusted
uint32_t
drop_mask
=
(
1
<<
(
mfmt
-
wm
))
-
1
;
bool
odd
=
mantissa
&
(
1
<<
(
mfmt
-
wm
));
// if the least significant bit that
// is not truncated is 1
mantissa
+=
(
stoch
?
rng
:
(
midpoint
?
(
odd
?
mantissa
:
mantissa
-
1
)
:
mantissa
))
&
drop_mask
;
// Now we deal with overflow
if
(
f8_exponent
==
0
)
{
if
((
1
<<
mfmt
)
&
mantissa
)
{
f8_exponent
=
1
;
// denormal overflow to become normal, promote exponent
}
}
else
{
if
((
1
<<
(
mfmt
+
1
))
&
mantissa
)
{
mantissa
>>=
1
;
f8_exponent
++
;
}
}
else
{
if
((
1
<<
(
mfmt
+
1
))
&
mantissa
)
{
mantissa
>>=
1
;
f8_exponent
++
;
}
}
mantissa
>>=
(
mfmt
-
wm
);
// above range: quantize to maximum possible float of the same sign
const
int
max_exp
=
(
1
<<
we
)
-
(
negative_zero_nan
?
1
:
2
);
if
(
f8_exponent
>
max_exp
)
{
if
(
clip
)
{
mantissa
=
(
1
<<
wm
)
-
1
;
f8_exponent
=
max_exp
;
}
else
{
return
signed_inf
;
}
}
mantissa
>>=
(
mfmt
-
wm
);
if
(
f8_exponent
==
0
&&
mantissa
==
0
)
{
return
negative_zero_nan
?
0
:
(
sign
<<
7
);
// above range: quantize to maximum possible float of the same sign
const
int
max_exp
=
(
1
<<
we
)
-
(
negative_zero_nan
?
1
:
2
);
if
(
f8_exponent
>
max_exp
)
{
if
(
clip
)
{
mantissa
=
(
1
<<
wm
)
-
1
;
f8_exponent
=
max_exp
;
}
else
{
return
signed_inf
;
}
mantissa
&=
(
1
<<
wm
)
-
1
;
return
(
sign
<<
7
)
|
(
f8_exponent
<<
wm
)
|
mantissa
;
}
if
(
f8_exponent
==
0
&&
mantissa
==
0
)
{
return
negative_zero_nan
?
0
:
(
sign
<<
7
);
}
mantissa
&=
(
1
<<
wm
)
-
1
;
return
(
sign
<<
7
)
|
(
f8_exponent
<<
wm
)
|
mantissa
;
}
template
<
int
we
,
int
wm
,
typename
T
=
float
,
bool
negative_zero_nan
=
true
>
inline
HIP_FP8_HOST_DEVICE
T
from_float8
(
uint8_t
x
)
{
inline
HIP_FP8_HOST_DEVICE
T
from_float8
(
uint8_t
x
)
{
#ifdef __HIPCC__
constexpr
bool
is_half
=
std
::
is_same
<
T
,
_Float16
>::
value
;
constexpr
bool
is_half
=
std
::
is_same
<
T
,
_Float16
>::
value
;
#else
constexpr
bool
is_half
=
false
;
constexpr
bool
is_half
=
false
;
#endif
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"only half and float are supported"
);
constexpr
bool
is_float
=
std
::
is_same
<
T
,
float
>::
value
;
static_assert
(
is_half
||
is_float
,
"only half and float are supported"
);
constexpr
int
weo
=
is_half
?
5
:
8
;
constexpr
int
wmo
=
is_half
?
10
:
(
is_float
?
23
:
7
);
constexpr
int
weo
=
is_half
?
5
:
8
;
constexpr
int
wmo
=
is_half
?
10
:
(
is_float
?
23
:
7
);
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
T
fInf
,
fNegInf
,
fNaN
,
fNeg0
;
#ifdef __HIPCC__
if
(
is_half
)
{
const
uint16_t
ihInf
=
0x7C00
;
const
uint16_t
ihNegInf
=
0xFC00
;
const
uint16_t
ihNaN
=
0x7C01
;
const
uint16_t
ihNeg0
=
0x8000
;
fInf
=
reinterpret_cast
<
const
_Float16
&>
(
ihInf
);
fNegInf
=
reinterpret_cast
<
const
_Float16
&>
(
ihNegInf
);
fNaN
=
reinterpret_cast
<
const
_Float16
&>
(
ihNaN
);
fNeg0
=
reinterpret_cast
<
const
_Float16
&>
(
ihNeg0
);
}
else
if
(
is_half
)
{
const
uint16_t
ihInf
=
0x7C00
;
const
uint16_t
ihNegInf
=
0xFC00
;
const
uint16_t
ihNaN
=
0x7C01
;
const
uint16_t
ihNeg0
=
0x8000
;
fInf
=
reinterpret_cast
<
const
_Float16
&>
(
ihInf
);
fNegInf
=
reinterpret_cast
<
const
_Float16
&>
(
ihNegInf
);
fNaN
=
reinterpret_cast
<
const
_Float16
&>
(
ihNaN
);
fNeg0
=
reinterpret_cast
<
const
_Float16
&>
(
ihNeg0
);
}
else
#endif
if
(
is_float
)
{
const
uint32_t
ifInf
=
0x7F800000
;
const
uint32_t
ifNegInf
=
0xFF800000
;
const
uint32_t
ifNaN
=
0x7F800001
;
const
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
reinterpret_cast
<
const
float
&>
(
ifInf
);
fNegInf
=
reinterpret_cast
<
const
float
&>
(
ifNegInf
);
fNaN
=
reinterpret_cast
<
const
float
&>
(
ifNaN
);
fNeg0
=
reinterpret_cast
<
const
float
&>
(
ifNeg0
);
}
if
(
x
==
0
)
{
return
0
;
}
uint32_t
sign
=
x
>>
7
;
uint32_t
mantissa
=
x
&
((
1
<<
wm
)
-
1
);
int
exponent
=
(
x
&
0x7F
)
>>
wm
;
if
(
negative_zero_nan
)
{
if
(
x
==
0x80
)
{
return
fNaN
;
}
}
else
{
if
(
x
==
0x80
)
{
return
fNeg0
;
}
if
(
exponent
==
((
1
<<
we
)
-
1
))
{
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
}
}
typename
std
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
if
(
we
==
5
&&
is_half
&&
!
negative_zero_nan
)
{
retval
=
x
<<
8
;
return
reinterpret_cast
<
const
T
&>
(
retval
);
if
(
is_float
)
{
const
uint32_t
ifInf
=
0x7F800000
;
const
uint32_t
ifNegInf
=
0xFF800000
;
const
uint32_t
ifNaN
=
0x7F800001
;
const
uint32_t
ifNeg0
=
0x80000000
;
fInf
=
reinterpret_cast
<
const
float
&>
(
ifInf
);
fNegInf
=
reinterpret_cast
<
const
float
&>
(
ifNegInf
);
fNaN
=
reinterpret_cast
<
const
float
&>
(
ifNaN
);
fNeg0
=
reinterpret_cast
<
const
float
&>
(
ifNeg0
);
}
if
(
x
==
0
)
{
return
0
;
}
uint32_t
sign
=
x
>>
7
;
uint32_t
mantissa
=
x
&
((
1
<<
wm
)
-
1
);
int
exponent
=
(
x
&
0x7F
)
>>
wm
;
if
(
negative_zero_nan
)
{
if
(
x
==
0x80
)
{
return
fNaN
;
}
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
// subnormal input
if
(
exponent
==
0
)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int
sh
=
1
+
clz
(
mantissa
)
-
(
32
-
wm
);
mantissa
<<=
sh
;
exponent
+=
1
-
sh
;
mantissa
&=
((
1
<<
wm
)
-
1
);
}
else
{
if
(
x
==
0x80
)
{
return
fNeg0
;
}
exponent
+=
exp_low_cutoff
-
1
;
mantissa
<<=
wmo
-
wm
;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if
(
exponent
<=
0
)
{
mantissa
|=
1
<<
wmo
;
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
}
if
(
sizeof
(
T
)
==
2
)
{
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
}
else
{
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
if
(
exponent
==
((
1
<<
we
)
-
1
))
{
return
(
mantissa
==
0
)
?
(
sign
?
fNegInf
:
fInf
)
:
fNaN
;
}
}
typename
std
::
conditional
<
sizeof
(
T
)
==
2
,
uint16_t
,
uint32_t
>::
type
retval
;
if
(
we
==
5
&&
is_half
&&
!
negative_zero_nan
)
{
retval
=
x
<<
8
;
return
reinterpret_cast
<
const
T
&>
(
retval
);
}
const
int
exp_low_cutoff
=
(
1
<<
(
weo
-
1
))
-
(
1
<<
(
we
-
1
))
+
1
-
(
negative_zero_nan
?
1
:
0
);
// subnormal input
if
(
exponent
==
0
)
{
// guaranteed mantissa!=0 since cases 0x0 and 0x80 are handled above
int
sh
=
1
+
clz
(
mantissa
)
-
(
32
-
wm
);
mantissa
<<=
sh
;
exponent
+=
1
-
sh
;
mantissa
&=
((
1
<<
wm
)
-
1
);
}
exponent
+=
exp_low_cutoff
-
1
;
mantissa
<<=
wmo
-
wm
;
// subnormal output (occurs when T=half, we=5, negative_zero_nan=true)
if
(
exponent
<=
0
)
{
mantissa
|=
1
<<
wmo
;
mantissa
>>=
1
-
exponent
;
exponent
=
0
;
}
if
(
sizeof
(
T
)
==
2
)
{
retval
=
(
sign
<<
15
)
|
(
exponent
<<
10
)
|
mantissa
;
}
else
{
retval
=
(
sign
<<
31
)
|
(
exponent
<<
23
)
|
mantissa
;
}
return
reinterpret_cast
<
const
T
&>
(
retval
);
}
}
// namespace hip_fp8_impl
}
// namespace hip_fp8_impl
Prev
1
2
3
4
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment