Commit 4217976d authored by zhushuang
Browse files

feat: rename Dequantize to DequantizeAWQ in nvidia gpu

parent d3d982df
...@@ -7,7 +7,7 @@ ...@@ -7,7 +7,7 @@
#include "infiniop/ops/causal_softmax.h" #include "infiniop/ops/causal_softmax.h"
#include "infiniop/ops/clip.h" #include "infiniop/ops/clip.h"
#include "infiniop/ops/conv.h" #include "infiniop/ops/conv.h"
#include "infiniop/ops/dequantize.h" #include "infiniop/ops/dequantize_awq.h"
#include "infiniop/ops/gemm.h" #include "infiniop/ops/gemm.h"
#include "infiniop/ops/mul.h" #include "infiniop/ops/mul.h"
#include "infiniop/ops/random_sample.h" #include "infiniop/ops/random_sample.h"
......
// Public C API of the Dequantize operator: descriptor creation, workspace
// query, execution and destruction. Every entry point returns an
// infiniStatus_t.
#ifndef __INFINIOP_DEQUANTIZE_API_H__
#define __INFINIOP_DEQUANTIZE_API_H__
#include "../operator_descriptor.h"
// Handle type for a Dequantize descriptor; it is a pointer to the generic
// InfiniopDescriptor struct.
typedef struct InfiniopDescriptor *infiniopDequantizeDescriptor_t;
// Creates a Dequantize descriptor.
//   handle       - library handle; the implementation selects the backend
//                  from handle->device.
//   desc_ptr     - output; receives the newly created descriptor.
//   out_desc     - tensor descriptor of the output.
//   qweight_desc - tensor descriptor of the quantized weights.
//   scales_desc  - tensor descriptor of the scales.
//   zeros_desc   - tensor descriptor of the zero points.
// NOTE(review): expected shapes/dtypes of the four tensors are not visible
// from this header — confirm against the backend implementation.
__C __export infiniStatus_t infiniopCreateDequantizeDescriptor(infiniopHandle_t handle,
infiniopDequantizeDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t qweight_desc,
infiniopTensorDescriptor_t scales_desc,
infiniopTensorDescriptor_t zeros_desc);
// Writes the descriptor's workspace size into *size.
__C __export infiniStatus_t infiniopGetDequantizeWorkspaceSize(infiniopDequantizeDescriptor_t desc, size_t *size);
// Runs the operator described by `desc`.
//   workspace / workspace_size - scratch buffer and its size; see
//                                infiniopGetDequantizeWorkspaceSize.
//   out                        - destination buffer (written).
//   qweight, scales, zeros     - read-only input buffers.
//   stream                     - presumably a backend stream handle
//                                (e.g. a CUDA stream on NVIDIA) — TODO confirm.
__C __export infiniStatus_t infiniopDequantize(infiniopDequantizeDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *qweight,
const void *scales,
const void *zeros,
void *stream);
// Destroys a descriptor obtained from infiniopCreateDequantizeDescriptor.
__C __export infiniStatus_t infiniopDestroyDequantizeDescriptor(infiniopDequantizeDescriptor_t desc);
#endif
// Public C API of the DequantizeAWQ operator: descriptor creation, workspace
// query, execution and destruction. Every entry point returns an
// infiniStatus_t.
#ifndef __INFINIOP_DEQUANTIZE_AWQ_API_H__
#define __INFINIOP_DEQUANTIZE_AWQ_API_H__
#include "../operator_descriptor.h"
// Handle type for a DequantizeAWQ descriptor; it is a pointer to the generic
// InfiniopDescriptor struct.
typedef struct InfiniopDescriptor *infiniopDequantizeAWQDescriptor_t;
// Creates a DequantizeAWQ descriptor.
//   handle       - library handle; the implementation selects the backend
//                  from handle->device.
//   desc_ptr     - output; receives the newly created descriptor.
//   out_desc     - tensor descriptor of the output.
//   qweight_desc - tensor descriptor of the quantized weights.
//   scales_desc  - tensor descriptor of the scales.
//   zeros_desc   - tensor descriptor of the zero points.
// NOTE(review): expected shapes/dtypes of the four tensors are not visible
// from this header — confirm against the backend implementation.
__C __export infiniStatus_t infiniopCreateDequantizeAWQDescriptor(infiniopHandle_t handle,
infiniopDequantizeAWQDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t qweight_desc,
infiniopTensorDescriptor_t scales_desc,
infiniopTensorDescriptor_t zeros_desc);
// Writes the descriptor's workspace size into *size.
__C __export infiniStatus_t infiniopGetDequantizeAWQWorkspaceSize(infiniopDequantizeAWQDescriptor_t desc, size_t *size);
// Runs the operator described by `desc`.
//   workspace / workspace_size - scratch buffer and its size; see
//                                infiniopGetDequantizeAWQWorkspaceSize.
//   out                        - destination buffer (written).
//   qweight, scales, zeros     - read-only input buffers.
//   stream                     - presumably a backend stream handle
//                                (e.g. a CUDA stream on NVIDIA) — TODO confirm.
__C __export infiniStatus_t infiniopDequantizeAWQ(infiniopDequantizeAWQDescriptor_t desc,
void *workspace,
size_t workspace_size,
void *out,
const void *qweight,
const void *scales,
const void *zeros,
void *stream);
// Destroys a descriptor obtained from infiniopCreateDequantizeAWQDescriptor.
__C __export infiniStatus_t infiniopDestroyDequantizeAWQDescriptor(infiniopDequantizeAWQDescriptor_t desc);
#endif
// NVIDIA backend header of the Dequantize operator.
#ifndef __DEQUANTIZE_CUDA_CUH__
#define __DEQUANTIZE_CUDA_CUH__
#include "../dequantize.h"
// Instantiates the DESCRIPTOR macro with `nvidia` as its namespace argument
// (the macro is expected to come from ../dequantize.h — confirm).
DESCRIPTOR(nvidia)
#endif // __DEQUANTIZE_CUDA_CUH__
#ifndef __DEQUANTIZE_H__ #ifndef __DEQUANTIZE_AWQ_H__
#define __DEQUANTIZE_H__ #define __DEQUANTIZE_AWQ_H__
#include "../../../utils.h" #include "../../../utils.h"
#include "../../operator.h" #include "../../operator.h"
...@@ -8,17 +8,17 @@ ...@@ -8,17 +8,17 @@
#define DESCRIPTOR(NAMESPACE) \ #define DESCRIPTOR(NAMESPACE) \
\ \
namespace op::dequantize::NAMESPACE { \ namespace op::dequantize_awq::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \ class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \ struct Opaque; \
Opaque *_opaque; \ Opaque *_opaque; \
DequantizeInfo _info; \ DequantizeAWQInfo _info; \
size_t _workspace_size; \ size_t _workspace_size; \
\ \
Descriptor( \ Descriptor( \
size_t workspace_size_, \ size_t workspace_size_, \
Opaque *opaque, \ Opaque *opaque, \
DequantizeInfo info, \ DequantizeAWQInfo info, \
infiniDevice_t device_type, \ infiniDevice_t device_type, \
int device_id) \ int device_id) \
: InfiniopDescriptor{device_type, device_id}, \ : InfiniopDescriptor{device_type, device_id}, \
...@@ -49,4 +49,5 @@ ...@@ -49,4 +49,5 @@
void *stream) const; \ void *stream) const; \
}; \ }; \
} }
#endif
#endif //__DEQUANTIZE_AWQ_H__
#ifndef __DEQUANTIZE_INFO_H__ #ifndef __DEQUANTIZE_AWQ_INFO_H__
#define __DEQUANTIZE_INFO_H__ #define __DEQUANTIZE_AWQ_INFO_H__
#include "../../../utils.h" #include "../../../utils.h"
#include "../../tensor.h" #include "../../tensor.h"
#include <vector> #include <vector>
namespace op::dequantize { namespace op::dequantize_awq {
class DequantizeInfo { class DequantizeAWQInfo {
DequantizeInfo() = default; DequantizeAWQInfo() = default;
public: public:
int _in_features, _out_features, _num_groups; int _in_features, _out_features, _num_groups;
...@@ -17,7 +17,7 @@ public: ...@@ -17,7 +17,7 @@ public:
int out_features() const { return _out_features; } int out_features() const { return _out_features; }
int num_groups() const { return _num_groups; } int num_groups() const { return _num_groups; }
static utils::Result<DequantizeInfo> create( static utils::Result<DequantizeAWQInfo> create(
infiniopTensorDescriptor_t out_desc, infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t qweight_desc, infiniopTensorDescriptor_t qweight_desc,
infiniopTensorDescriptor_t scales_desc, infiniopTensorDescriptor_t scales_desc,
...@@ -27,13 +27,13 @@ public: ...@@ -27,13 +27,13 @@ public:
int _out_features = qweight_desc->dim(1); int _out_features = qweight_desc->dim(1);
int _num_groups = scales_desc->dim(0); int _num_groups = scales_desc->dim(0);
return utils::Result<DequantizeInfo>(DequantizeInfo{ return utils::Result<DequantizeAWQInfo>(DequantizeAWQInfo{
_in_features, _in_features,
_out_features, _out_features,
_num_groups}); _num_groups});
} }
}; };
} // namespace op::dequantize } // namespace op::dequantize_awq
#endif // __DEQUANTIZE_INFO_H__ #endif // __DEQUANTIZE_AWQ_INFO_H__
...@@ -5,7 +5,7 @@ ...@@ -5,7 +5,7 @@
#include "dequantize_w42f16_kernel.cuh" #include "dequantize_w42f16_kernel.cuh"
#include "dequantize_w42f16_nvidia.cuh" #include "dequantize_w42f16_nvidia.cuh"
#include "../dequantize.h" #include "../dequantize_awq.h"
#include <cuda_fp16.h> #include <cuda_fp16.h>
__global__ void __launch_bounds__(64) __global__ void __launch_bounds__(64)
...@@ -68,7 +68,7 @@ __global__ void __launch_bounds__(64) ...@@ -68,7 +68,7 @@ __global__ void __launch_bounds__(64)
} }
} }
namespace op::dequantize::nvidia { namespace op::dequantize_awq::nvidia {
struct Descriptor::Opaque { struct Descriptor::Opaque {
std::shared_ptr<device::nvidia::Handle::Internal> internal; std::shared_ptr<device::nvidia::Handle::Internal> internal;
...@@ -87,7 +87,7 @@ infiniStatus_t Descriptor::create( ...@@ -87,7 +87,7 @@ infiniStatus_t Descriptor::create(
infiniopTensorDescriptor_t zeros_desc) { infiniopTensorDescriptor_t zeros_desc) {
auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_); auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
auto result = DequantizeInfo::create(out_desc, qweight_desc, scales_desc, zeros_desc); auto result = DequantizeAWQInfo::create(out_desc, qweight_desc, scales_desc, zeros_desc);
*desc_ptr = new Descriptor( *desc_ptr = new Descriptor(
0, 0,
...@@ -133,6 +133,6 @@ Descriptor::calculate( ...@@ -133,6 +133,6 @@ Descriptor::calculate(
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
} }
} // namespace op::dequantize::nvidia } // namespace op::dequantize_awq::nvidia
#endif #endif
// NVIDIA backend header of the DequantizeAWQ operator.
#ifndef __DEQUANTIZE_AWQ_CUDA_CUH__
#define __DEQUANTIZE_AWQ_CUDA_CUH__
#include "../dequantize_awq.h"
// Instantiates the DESCRIPTOR macro with `nvidia` as its namespace argument
// (the macro is expected to come from ../dequantize_awq.h — confirm).
DESCRIPTOR(nvidia)
#endif // __DEQUANTIZE_AWQ_CUDA_CUH__
#include "../../operator.h" #include "../../operator.h"
#include "../../handle.h" #include "../../handle.h"
#include "infiniop/ops/dequantize.h" #include "infiniop/ops/dequantize_awq.h"
#ifdef ENABLE_NVIDIA_API #ifdef ENABLE_NVIDIA_API
#include "nvidia/dequantize_w42f16_nvidia.cuh" #include "nvidia/dequantize_w42f16_nvidia.cuh"
#endif #endif
__C infiniStatus_t infiniopCreateDequantizeDescriptor( __C infiniStatus_t infiniopCreateDequantizeAWQDescriptor(
infiniopHandle_t handle, infiniopHandle_t handle,
infiniopDequantizeDescriptor_t *desc_ptr, infiniopDequantizeAWQDescriptor_t *desc_ptr,
infiniopTensorDescriptor_t out_desc, infiniopTensorDescriptor_t out_desc,
infiniopTensorDescriptor_t qweight_desc, infiniopTensorDescriptor_t qweight_desc,
infiniopTensorDescriptor_t scales_desc, infiniopTensorDescriptor_t scales_desc,
infiniopTensorDescriptor_t zeros_desc) { infiniopTensorDescriptor_t zeros_desc) {
#define CREATE(CASE, NAMESPACE) \ #define CREATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
return op::dequantize::NAMESPACE::Descriptor::create( \ return op::dequantize_awq::NAMESPACE::Descriptor::create( \
handle, \ handle, \
reinterpret_cast<op::dequantize::NAMESPACE::Descriptor **>(desc_ptr), \ reinterpret_cast<op::dequantize_awq::NAMESPACE::Descriptor **>(desc_ptr), \
out_desc, \ out_desc, \
qweight_desc, \ qweight_desc, \
scales_desc, \ scales_desc, \
zeros_desc) zeros_desc)
switch (handle->device) { switch (handle->device) {
...@@ -35,11 +35,11 @@ __C infiniStatus_t infiniopCreateDequantizeDescriptor( ...@@ -35,11 +35,11 @@ __C infiniStatus_t infiniopCreateDequantizeDescriptor(
#undef CREATE #undef CREATE
} }
__C infiniStatus_t infiniopGetDequantizeWorkspaceSize(infiniopDequantizeDescriptor_t desc, __C infiniStatus_t infiniopGetDequantizeAWQWorkspaceSize(infiniopDequantizeAWQDescriptor_t desc,
size_t *size) { size_t *size) {
#define GET(CASE, NAMESPACE) \ #define GET(CASE, NAMESPACE) \
case CASE: \ case CASE: \
*size = reinterpret_cast<const op::dequantize::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \ *size = reinterpret_cast<const op::dequantize_awq::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
return INFINI_STATUS_SUCCESS return INFINI_STATUS_SUCCESS
switch (desc->device_type) { switch (desc->device_type) {
...@@ -52,8 +52,8 @@ __C infiniStatus_t infiniopGetDequantizeWorkspaceSize(infiniopDequantizeDescript ...@@ -52,8 +52,8 @@ __C infiniStatus_t infiniopGetDequantizeWorkspaceSize(infiniopDequantizeDescript
#undef GET #undef GET
} }
__C infiniStatus_t infiniopDequantize( __C infiniStatus_t infiniopDequantizeAWQ(
infiniopDequantizeDescriptor_t desc, infiniopDequantizeAWQDescriptor_t desc,
void *workspace, void *workspace,
size_t workspace_size, size_t workspace_size,
void *out, void *out,
...@@ -62,9 +62,9 @@ __C infiniStatus_t infiniopDequantize( ...@@ -62,9 +62,9 @@ __C infiniStatus_t infiniopDequantize(
const void *zeros, const void *zeros,
void *stream) { void *stream) {
#define CALCULATE(CASE, NAMESPACE) \ #define CALCULATE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
return reinterpret_cast<const op::dequantize::NAMESPACE::Descriptor *>(desc) \ return reinterpret_cast<const op::dequantize_awq::NAMESPACE::Descriptor *>(desc) \
->calculate(workspace, workspace_size, out, qweight, scales, zeros, stream) ->calculate(workspace, workspace_size, out, qweight, scales, zeros, stream)
switch (desc->device_type) { switch (desc->device_type) {
...@@ -79,11 +79,11 @@ __C infiniStatus_t infiniopDequantize( ...@@ -79,11 +79,11 @@ __C infiniStatus_t infiniopDequantize(
} }
__C infiniStatus_t __C infiniStatus_t
infiniopDestroyDequantizeDescriptor(infiniopDequantizeDescriptor_t desc) { infiniopDestroyDequantizeAWQDescriptor(infiniopDequantizeAWQDescriptor_t desc) {
#define DELETE(CASE, NAMESPACE) \ #define DELETE(CASE, NAMESPACE) \
case CASE: \ case CASE: \
delete reinterpret_cast<const op::dequantize::NAMESPACE::Descriptor *>(desc); \ delete reinterpret_cast<const op::dequantize_awq::NAMESPACE::Descriptor *>(desc); \
return INFINI_STATUS_SUCCESS; return INFINI_STATUS_SUCCESS;
switch (desc->device_type) { switch (desc->device_type) {
......
...@@ -140,7 +140,7 @@ AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7] ...@@ -140,7 +140,7 @@ AWQ_ORDER = [0, 2, 4, 6, 1, 3, 5, 7]
AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7] AWQ_REVERSE_ORDER = [0, 4, 1, 5, 2, 6, 3, 7]
def dequantize( def dequantize_awq(
qweight: torch.Tensor, qweight: torch.Tensor,
qzeros: torch.Tensor, qzeros: torch.Tensor,
qscales: torch.Tensor, qscales: torch.Tensor,
...@@ -216,7 +216,7 @@ def test( ...@@ -216,7 +216,7 @@ def test(
sync=None, sync=None,
): ):
print( print(
f"Testing Dequantize on {InfiniDeviceNames[device]} with bits:{bits}, group_size:{group_size}," f"Testing Dequantize AWQ on {InfiniDeviceNames[device]} with bits:{bits}, group_size:{group_size},"
f" qweights_shape:{qweights_shape}, qzeros_shape:{qzeros_shape}, qscales_shape:{qscales_shape}," f" qweights_shape:{qweights_shape}, qzeros_shape:{qzeros_shape}, qscales_shape:{qscales_shape},"
f" qweights_stride:{qweights_stride}, qzeros_stride:{qzeros_stride}, qscales_stride:{qscales_stride}," f" qweights_stride:{qweights_stride}, qzeros_stride:{qzeros_stride}, qscales_stride:{qscales_stride},"
f" qweights_dtype:{InfiniDtypeNames[qweights_dtype]}, qzeros_dtype:{InfiniDtypeNames[qzeros_dtype]}, qscales_dtype:{InfiniDtypeNames[qscales_dtype]}" f" qweights_dtype:{InfiniDtypeNames[qweights_dtype]}, qzeros_dtype:{InfiniDtypeNames[qzeros_dtype]}, qscales_dtype:{InfiniDtypeNames[qscales_dtype]}"
...@@ -225,14 +225,16 @@ def test( ...@@ -225,14 +225,16 @@ def test(
qweights = TestTensor( qweights = TestTensor(
qweights_shape, qweights_stride, qweights_dtype, device, mode="randint" qweights_shape, qweights_stride, qweights_dtype, device, mode="randint"
) )
qzeros = TestTensor(qzeros_shape, qzeros_stride, qzeros_dtype, device, mode="randint") qzeros = TestTensor(
qzeros_shape, qzeros_stride, qzeros_dtype, device, mode="randint"
)
qscales = TestTensor(qscales_shape, qscales_stride, qscales_dtype, device) qscales = TestTensor(qscales_shape, qscales_stride, qscales_dtype, device)
out = TestTensor(out_shape, out_stride, out_dtype, device, mode="zeros") out = TestTensor(out_shape, out_stride, out_dtype, device, mode="zeros")
ans = TestTensor(out_shape, out_stride, out_dtype, device, mode="ones") ans = TestTensor(out_shape, out_stride, out_dtype, device, mode="ones")
# Compute the PyTorch reference result # Compute the PyTorch reference result
def torch_dequantize(): def torch_dequantize_awq():
return dequantize( return dequantize_awq(
qweights.torch_tensor(), qweights.torch_tensor(),
qzeros.torch_tensor(), qzeros.torch_tensor(),
qscales.torch_tensor(), qscales.torch_tensor(),
...@@ -240,14 +242,14 @@ def test( ...@@ -240,14 +242,14 @@ def test(
group_size, group_size,
) )
ans = torch_dequantize() ans = torch_dequantize_awq()
if sync is not None: if sync is not None:
sync() sync()
descriptor = infiniopOperatorDescriptor_t() descriptor = infiniopOperatorDescriptor_t()
check_error( check_error(
LIBINFINIOP.infiniopCreateDequantizeDescriptor( LIBINFINIOP.infiniopCreateDequantizeAWQDescriptor(
handle, handle,
ctypes.byref(descriptor), ctypes.byref(descriptor),
out.descriptor, out.descriptor,
...@@ -264,16 +266,16 @@ def test( ...@@ -264,16 +266,16 @@ def test(
# Get workspace size and create workspace # Get workspace size and create workspace
workspace_size = c_uint64(0) workspace_size = c_uint64(0)
check_error( check_error(
LIBINFINIOP.infiniopGetDequantizeWorkspaceSize( LIBINFINIOP.infiniopGetDequantizeAWQWorkspaceSize(
descriptor, ctypes.byref(workspace_size) descriptor, ctypes.byref(workspace_size)
) )
) )
workspace = TestWorkspace(workspace_size.value, device) workspace = TestWorkspace(workspace_size.value, device)
# Execute infiniop AWQ dequantize operator # Execute infiniop AWQ dequantize operator
def lib_dequantize(): def lib_dequantize_awq():
check_error( check_error(
LIBINFINIOP.infiniopDequantize( LIBINFINIOP.infiniopDequantizeAWQ(
descriptor, descriptor,
workspace.data(), workspace.data(),
workspace_size.value, workspace_size.value,
...@@ -285,7 +287,7 @@ def test( ...@@ -285,7 +287,7 @@ def test(
) )
) )
lib_dequantize() lib_dequantize_awq()
# Validate results # Validate results
atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype) atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
...@@ -298,10 +300,10 @@ def test( ...@@ -298,10 +300,10 @@ def test(
# Profiling workflow # Profiling workflow
if PROFILE: if PROFILE:
# fmt: off # fmt: off
profile_operation("PyTorch", lambda: torch_dequantize(), device, NUM_PRERUN, NUM_ITERATIONS) profile_operation("PyTorch", lambda: torch_dequantize_awq(), device, NUM_PRERUN, NUM_ITERATIONS)
profile_operation(" lib", lambda: lib_dequantize(), device, NUM_PRERUN, NUM_ITERATIONS) profile_operation(" lib", lambda: lib_dequantize_awq(), device, NUM_PRERUN, NUM_ITERATIONS)
# fmt: on # fmt: on
check_error(LIBINFINIOP.infiniopDestroyDequantizeDescriptor(descriptor)) check_error(LIBINFINIOP.infiniopDestroyDequantizeAWQDescriptor(descriptor))
# ============================================================================== # ==============================================================================
......
...@@ -533,8 +533,8 @@ def topkrouter_(lib): ...@@ -533,8 +533,8 @@ def topkrouter_(lib):
@OpRegister.operator @OpRegister.operator
def dequantize_(lib): def dequantize_(lib):
lib.infiniopCreateDequantizeDescriptor.restype = c_int32 lib.infiniopCreateDequantizeAWQDescriptor.restype = c_int32
lib.infiniopCreateDequantizeDescriptor.argtypes = [ lib.infiniopCreateDequantizeAWQDescriptor.argtypes = [
infiniopHandle_t, infiniopHandle_t,
POINTER(infiniopOperatorDescriptor_t), POINTER(infiniopOperatorDescriptor_t),
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
...@@ -542,13 +542,13 @@ def dequantize_(lib): ...@@ -542,13 +542,13 @@ def dequantize_(lib):
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
infiniopTensorDescriptor_t, infiniopTensorDescriptor_t,
] ]
lib.infiniopGetDequantizeWorkspaceSize.restype = c_int32 lib.infiniopGetDequantizeAWQWorkspaceSize.restype = c_int32
lib.infiniopGetDequantizeWorkspaceSize.argtypes = [ lib.infiniopGetDequantizeAWQWorkspaceSize.argtypes = [
infiniopOperatorDescriptor_t, infiniopOperatorDescriptor_t,
POINTER(c_size_t), POINTER(c_size_t),
] ]
lib.infiniopDequantize.restype = c_int32 lib.infiniopDequantizeAWQ.restype = c_int32
lib.infiniopDequantize.argtypes = [ lib.infiniopDequantizeAWQ.argtypes = [
infiniopOperatorDescriptor_t, infiniopOperatorDescriptor_t,
c_void_p, c_void_p,
c_size_t, c_size_t,
...@@ -557,8 +557,8 @@ def dequantize_(lib): ...@@ -557,8 +557,8 @@ def dequantize_(lib):
c_void_p, c_void_p,
c_void_p, c_void_p,
] ]
lib.infiniopDestroyDequantizeDescriptor.restype = c_int32 lib.infiniopDestroyDequantizeAWQDescriptor.restype = c_int32
lib.infiniopDestroyDequantizeDescriptor.argtypes = [ lib.infiniopDestroyDequantizeAWQDescriptor.argtypes = [
infiniopOperatorDescriptor_t, infiniopOperatorDescriptor_t,
] ]
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment