jerrrrry / infinicore · Commits

Commit 25258029 (unverified), authored Dec 25, 2025 by qinyiqun, committed by GitHub on Dec 25, 2025

Issue/840: NVIDIA Int8 GEMM (#841)

parent 12cde8eb

Showing 14 changed files with 2496 additions and 9 deletions
include/infiniop/ops/int8_gemm.h                                     +32   -0
src/infiniop/ops/scaled_mm/info.h                                    +121  -0
src/infiniop/ops/scaled_mm/int8_gemm.h                               +46   -0
src/infiniop/ops/scaled_mm/nvidia/epilogue_per_row_per_col_scale.h   +308  -0
src/infiniop/ops/scaled_mm/nvidia/gemm_universal_base_compat.h       +354  -0
src/infiniop/ops/scaled_mm/nvidia/gemm_with_epilogue_visitor.h       +487  -0
src/infiniop/ops/scaled_mm/nvidia/int8_gemm_kernel.cuh               +693  -0
src/infiniop/ops/scaled_mm/nvidia/int8_gemm_nvidia.cu                +115  -0
src/infiniop/ops/scaled_mm/nvidia/int8_gemm_nvidia.cuh               +7    -0
src/infiniop/ops/scaled_mm/operator.cc                               +90   -0
test/infiniop/libinfiniop/op_register.py                             +39   -0
test/infiniop/libinfiniop/utils.py                                   +1    -1
test/infiniop/scaled_mm_int8.py                                      +192  -0
xmake/nvidia.lua                                                     +11   -8
include/infiniop/ops/int8_gemm.h
new file mode 100644

#ifndef __INFINIOP_I8GEMM_API_H__
#define __INFINIOP_I8GEMM_API_H__

#include "../operator_descriptor.h"

typedef InfiniopDescriptor *infiniopI8GemmDescriptor_t;

__C __export infiniStatus_t infiniopCreateI8GemmDescriptor(
    infiniopHandle_t handle,
    infiniopI8GemmDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t bias_desc,
    infiniopTensorDescriptor_t x_desc,
    infiniopTensorDescriptor_t x_scale_desc,
    infiniopTensorDescriptor_t weights_desc,
    infiniopTensorDescriptor_t weights_scale_desc);

__C __export infiniStatus_t infiniopGetI8GemmWorkspaceSize(
    infiniopI8GemmDescriptor_t desc,
    size_t *size);

__C __export infiniStatus_t infiniopI8Gemm(
    infiniopI8GemmDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *bias,
    const void *x,
    const void *x_scale,
    const void *weights,
    const void *weights_scale,
    void *stream);

__C __export infiniStatus_t infiniopDestroyI8GemmDescriptor(
    infiniopI8GemmDescriptor_t desc);

#endif
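For orientation, a minimal usage sketch of the four entry points declared above. It assumes the handle, the tensor descriptors, and the device buffers were created elsewhere through the usual infiniop APIs (none of which appear in this commit); run_i8gemm and its parameter names are illustrative only, not part of the library.

// Hedged usage sketch: drives only the four functions declared above.
// All descriptors and device pointers are assumed to exist already;
// error handling is reduced to early returns.
infiniStatus_t run_i8gemm(infiniopHandle_t handle,
                          infiniopTensorDescriptor_t out_desc,
                          infiniopTensorDescriptor_t bias_desc,
                          infiniopTensorDescriptor_t x_desc,
                          infiniopTensorDescriptor_t x_scale_desc,
                          infiniopTensorDescriptor_t w_desc,
                          infiniopTensorDescriptor_t w_scale_desc,
                          void *out, const void *bias,
                          const void *x, const void *x_scale,
                          const void *w, const void *w_scale,
                          void *workspace_buf, size_t workspace_cap,
                          void *stream) {
    infiniopI8GemmDescriptor_t desc = nullptr;
    infiniStatus_t st = infiniopCreateI8GemmDescriptor(
        handle, &desc, out_desc, bias_desc, x_desc, x_scale_desc, w_desc, w_scale_desc);
    if (st != INFINI_STATUS_SUCCESS) {
        return st;
    }
    size_t workspace_size = 0;
    st = infiniopGetI8GemmWorkspaceSize(desc, &workspace_size);
    if (st == INFINI_STATUS_SUCCESS && workspace_size <= workspace_cap) {
        st = infiniopI8Gemm(desc, workspace_buf, workspace_size,
                            out, bias, x, x_scale, w, w_scale, stream);
    }
    infiniopDestroyI8GemmDescriptor(desc);
    return st;
}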
src/infiniop/ops/scaled_mm/info.h
new file mode 100644

#ifndef __I8GEMM_INFO_H__
#define __I8GEMM_INFO_H__

#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"

#include <algorithm>

namespace op::i8gemm {

struct BlasMatrix {
    int ndim;
    int batch;
    int stride;
    int rows;
    int cols;
    int row_stride;
    int col_stride;

    static utils::Result<BlasMatrix> create(infiniopTensorDescriptor_t layout) {
        BlasMatrix ans;
        if (layout->ndim() == 2) {
            ans.ndim = 2;
            ans.batch = 1;
            ans.stride = 0;
            ans.rows = layout->dim(0);
            ans.cols = layout->dim(1);
            ans.row_stride = layout->stride(0);
            ans.col_stride = layout->stride(1);
        } else if (layout->ndim() == 3) {
            ans.ndim = 3;
            ans.batch = layout->dim(0);
            ans.stride = ans.batch == 1 ? 0 : layout->stride(0);
            ans.rows = layout->dim(1);
            ans.cols = layout->dim(2);
            ans.row_stride = layout->stride(1);
            ans.col_stride = layout->stride(2);
        } else {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (ans.row_stride != 1 && ans.col_stride != 1) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }
        return utils::Result<BlasMatrix>(ans);
    }

    bool match_batch(int _batch) const {
        return batch == _batch || batch == 1;
    }

    void transpose() {
        std::swap(rows, cols);
        std::swap(row_stride, col_stride);
    }

    int ld() const {
        return row_stride == 1 ? col_stride : row_stride;
    }
};

enum class MatrixLayout : char {
    COL_MAJOR,
    ROW_MAJOR,
};

class I8GemmInfo {
    I8GemmInfo() = default;

public:
    BlasMatrix a_matrix;
    BlasMatrix b_matrix;
    BlasMatrix out_matrix;
    int m, n, k, batch;

    static utils::Result<I8GemmInfo> create(
        infiniopTensorDescriptor_t out_desc,
        infiniopTensorDescriptor_t a_desc,
        infiniopTensorDescriptor_t b_desc,
        MatrixLayout layout) {
        auto a_matrix = BlasMatrix::create(a_desc);
        CHECK_RESULT(a_matrix);
        auto b_matrix = BlasMatrix::create(b_desc);
        CHECK_RESULT(b_matrix);
        auto out_matrix = BlasMatrix::create(out_desc);
        CHECK_RESULT(out_matrix);
        if (out_matrix->rows != a_matrix->rows
            || out_matrix->cols != b_matrix->cols
            || a_matrix->cols != b_matrix->rows) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        auto batch = out_matrix->batch;
        if (!a_matrix->match_batch(batch) || !b_matrix->match_batch(batch)) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        auto m = out_matrix->rows;
        auto n = out_matrix->cols;
        auto k = a_matrix->cols;
        return utils::Result<I8GemmInfo>(I8GemmInfo{
            a_matrix.take(),
            b_matrix.take(),
            out_matrix.take(),
            m, n, k, batch});
    }
};

} // namespace op::i8gemm

#endif // __I8GEMM_INFO_H__
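A quick worked reading of BlasMatrix: create() requires at least one of row_stride/col_stride to be 1 (otherwise INFINI_STATUS_BAD_TENSOR_STRIDES), and ld() picks whichever stride is not 1 as the leading dimension. A standalone restatement of that rule, not library code, using shapes that appear in the test file below:

#include <cassert>

// Standalone restatement of the BlasMatrix::ld() rule above:
// the leading dimension is whichever of the two strides is not 1.
static int leading_dim(int row_stride, int col_stride) {
    return row_stride == 1 ? col_stride : row_stride;
}

int main() {
    // Row-major (128, 512) activations: strides (512, 1) -> ld = 512.
    assert(leading_dim(512, 1) == 512);
    // Column-major (512, 1024) weights (a transposed (1024, 512) tensor):
    // strides (1, 512) -> ld = 512.
    assert(leading_dim(1, 512) == 512);
    return 0;
}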
src/infiniop/ops/scaled_mm/int8_gemm.h
new file mode 100644

#ifndef __I8GEMM_H__
#define __I8GEMM_H__

#include "../../operator.h"
#include "info.h"

#define DESCRIPTOR(NAMESPACE)                                                    \
                                                                                 \
    namespace op::i8gemm::NAMESPACE {                                            \
    class Descriptor final : public InfiniopDescriptor {                         \
        struct Opaque;                                                           \
        Opaque *_opaque;                                                         \
        size_t _workspace_size;                                                  \
        I8GemmInfo _info;                                                        \
        infiniDtype_t _out_dtype;                                                \
                                                                                 \
        Descriptor(Opaque *opaque, I8GemmInfo info,                              \
                   size_t workspace_size,                                        \
                   infiniDtype_t out_dtype,                                      \
                   infiniDevice_t device_type, int device_id)                    \
            : InfiniopDescriptor{device_type, device_id}, _out_dtype(out_dtype), \
              _opaque(opaque), _info(info), _workspace_size(workspace_size) {}   \
                                                                                 \
    public:                                                                      \
        ~Descriptor();                                                           \
                                                                                 \
        size_t minWorkspaceSize() const { return _workspace_size; }              \
                                                                                 \
        static infiniStatus_t create(                                            \
            infiniopHandle_t handle, Descriptor **desc_ptr,                      \
            infiniopTensorDescriptor_t out_desc,                                 \
            infiniopTensorDescriptor_t bias_desc,                                \
            infiniopTensorDescriptor_t a_desc,                                   \
            infiniopTensorDescriptor_t a_scale_desc,                             \
            infiniopTensorDescriptor_t b_desc,                                   \
            infiniopTensorDescriptor_t b_scale_desc);                            \
                                                                                 \
        infiniStatus_t calculate(                                                \
            void *workspace, size_t workspace_size,                              \
            void *out, const void *bias, const void *a,                          \
            const void *a_scale, const void *b,                                  \
            const void *b_scale, void *stream) const;                            \
    };                                                                           \
    }

#endif // __I8GEMM_H__
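The expansion of this macro is mechanical; for the NVIDIA backend, which instantiates it as DESCRIPTOR(nvidia) in int8_gemm_nvidia.cuh below, the generated class looks roughly like this, abbreviated to its public surface:

// Sketch of what DESCRIPTOR(nvidia) expands to (private members elided).
namespace op::i8gemm::nvidia {
class Descriptor final : public InfiniopDescriptor {
    // ... Opaque state, I8GemmInfo, workspace size, output dtype ...
public:
    ~Descriptor();
    size_t minWorkspaceSize() const;
    static infiniStatus_t create(infiniopHandle_t handle, Descriptor **desc_ptr,
                                 infiniopTensorDescriptor_t out_desc,
                                 infiniopTensorDescriptor_t bias_desc,
                                 infiniopTensorDescriptor_t a_desc,
                                 infiniopTensorDescriptor_t a_scale_desc,
                                 infiniopTensorDescriptor_t b_desc,
                                 infiniopTensorDescriptor_t b_scale_desc);
    infiniStatus_t calculate(void *workspace, size_t workspace_size,
                             void *out, const void *bias, const void *a,
                             const void *a_scale, const void *b,
                             const void *b_scale, void *stream) const;
};
} // namespace op::i8gemm::nvidia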
src/infiniop/ops/scaled_mm/nvidia/epilogue_per_row_per_col_scale.h
new file mode 100644

/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h

#pragma once

#include <cutlass/arch/memory.h>
#include <cutlass/numeric_conversion.h>

namespace cutlass {
namespace epilogue {
namespace threadblock {

template <typename ThreadblockShape_, int ThreadCount,
          typename ScaleTileIterator_, typename OutputTileIterator_,
          typename ElementAccumulator_, typename ElementCompute_,
          typename ElementwiseFunctor_, bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
public:
    using ThreadblockShape = ThreadblockShape_;
    static int const kThreadCount = ThreadCount;

    using ScaleTileIterator = ScaleTileIterator_;
    using OutputTileIterator = OutputTileIterator_;
    using ElementwiseFunctor = ElementwiseFunctor_;

    static int const kIterations = OutputTileIterator::kIterations;
    static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

    using ElementOutput = typename OutputTileIterator::Element;
    using LayoutOutput = cutlass::layout::RowMajor;
    using ElementAccumulator = ElementAccumulator_;

    using AlphaScaleElementType = typename ScaleTileIterator::Element;

    using ElementCompute = ElementCompute_;
    using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
    using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
    using OutputVector = Array<ElementOutput, kElementsPerAccess>;

    static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
    static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);

    /// Argument structure
    struct Arguments {
        typename ElementwiseFunctor::Params elementwise;
        int64_t batch_stride_alpha;
        int64_t batch_stride_C;
        int64_t batch_stride_D;

        //
        // Methods
        //
        Arguments()
            : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}

        Arguments(typename ElementwiseFunctor::Params elementwise_)
            : elementwise(elementwise_),
              batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}

        Arguments(typename ElementwiseFunctor::Params elementwise_,
                  int64_t batch_stride_alpha_,
                  int64_t batch_stride_C_,
                  int64_t batch_stride_D_)
            : elementwise(elementwise_),
              batch_stride_alpha(batch_stride_alpha_),
              batch_stride_C(batch_stride_C_),
              batch_stride_D(batch_stride_D_) {}
    };

    struct Params {
        typename ElementwiseFunctor::Params elementwise;
        int64_t batch_stride_alpha;
        int64_t batch_stride_C;
        int64_t batch_stride_D;

        //
        // Methods
        //
        CUTLASS_HOST_DEVICE
        Params() {}

        CUTLASS_HOST_DEVICE
        Params(Arguments const &args)
            : elementwise(args.elementwise),
              batch_stride_alpha(args.batch_stride_alpha),
              batch_stride_C(args.batch_stride_C),
              batch_stride_D(args.batch_stride_D) {}
    };

    /// Shared storage
    struct SharedStorage {};

private:
    Params const &params_;
    SharedStorage &shared_storage_;
    MatrixCoord extent_;
    MatrixCoord extent_real_;
    ElementwiseFunctor elementwise_;

    bool const with_bias_;
    bool const per_token_quant_;
    bool const per_channel_quant_;

    AlphaScaleElementType *ptr_alpha_row_;
    AlphaScaleElementType *ptr_alpha_col_;
    ScaleTileIterator iterator_alpha_col_;
    OutputTileIterator iterator_C_;
    OutputTileIterator iterator_D_;

    AlphaScaleElementType element_alpha_row_ = 1.0f;
    AlphaScaleElementType element_alpha_col_ = 1.0f;
    typename ScaleTileIterator::Fragment fragment_alpha_col_;
    typename OutputTileIterator::Fragment fragment_C_;
    typename OutputTileIterator::Fragment fragment_D_;

    ElementAccumulator beta_;
    int column_offset_;
    MatrixCoord thread_offset_;

public:
    CUTLASS_DEVICE
    EpilogueVisitorPerRowPerCol(
        Params const &params,
        SharedStorage &shared_storage,
        cutlass::MatrixCoord const &problem_size,
        int thread_idx,
        int warp_idx,
        int lane_idx,
        typename ScaleTileIterator::Params params_alpha_col,
        typename OutputTileIterator::Params params_C,
        typename OutputTileIterator::Params params_D,
        bool with_bias,
        bool per_token_quant,
        bool per_channel_quant,
        AlphaScaleElementType *ptr_alpha_row,
        AlphaScaleElementType *ptr_alpha_col,
        typename OutputTileIterator::Element *ptr_C,
        typename OutputTileIterator::Element *ptr_D,
        cutlass::MatrixCoord const &threadblock_offset = cutlass::MatrixCoord(0, 0),
        int column_offset = 0,
        cutlass::MatrixCoord const &problem_size_real = cutlass::MatrixCoord(0, 0))
        : params_(params),
          shared_storage_(shared_storage),
          extent_(problem_size),
          elementwise_(params.elementwise),
          with_bias_(with_bias),
          per_token_quant_(per_token_quant),
          per_channel_quant_(per_channel_quant),
          ptr_alpha_row_(ptr_alpha_row),
          ptr_alpha_col_(ptr_alpha_col),
          iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
          iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
          iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
          extent_real_(problem_size_real) {
        if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) {
            element_alpha_col_ = *ptr_alpha_col_;
        }
        if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) {
            element_alpha_row_ = *ptr_alpha_row_;
        }
    }

    /// Helper to indicate split-K behavior
    CUTLASS_DEVICE
    void set_k_partition(int split_k_index,    ///< Index of this threadblock within split-K partitioned scheme
                         int split_k_slices) { ///< Total number of split-K slices
    }

    /// Called to set the batch index
    CUTLASS_DEVICE
    void set_batch_index(int batch_idx) {
        iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
        iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
        iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
    }

    /// Called at the start of the epilogue just before iterating over accumulator slices
    CUTLASS_DEVICE
    void begin_epilogue() {
        if (per_channel_quant_) {
            iterator_alpha_col_.load(fragment_alpha_col_);
        }
        if (with_bias_) {
            iterator_C_.load(fragment_C_);
        }
    }

    /// Called at the start of one step before starting accumulator exchange
    CUTLASS_DEVICE
    void begin_step(int step_idx) {
        fragment_D_.clear();
    }

    /// Called at the start of a row
    CUTLASS_DEVICE
    void begin_row(int row_idx) {
        // Load alpha_row here only when per-token (per-row) scaling is used
        if (per_token_quant_) {
            int thread_offset_row = iterator_D_.thread_start_row()
                + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();
            arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
                element_alpha_row_, ptr_alpha_row_ + thread_offset_row,
                thread_offset_row < extent_.row());
        }
    }

    /// Called after accumulators have been exchanged for each accumulator vector
    CUTLASS_DEVICE
    void visit(int iter_idx, int row_idx, int column_idx, int frag_idx,
               AccumulatorFragment const &accum) {
        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;
        ComputeFragment result = source_converter(accum);
        if (per_channel_quant_) {
            ComputeFragment alpha_col = reinterpret_cast<ComputeFragment *>(&fragment_alpha_col_)[column_idx];
            result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
        } else {
            result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
        }
        if (with_bias_) {
            NumericArrayConverter<ElementCompute, ElementOutput, kElementsPerAccess> bias_converter;
            OutputVector bias = reinterpret_cast<OutputVector *>(&fragment_C_)[column_idx];
            result = bias_accumulator_(result, bias_converter(bias));
        }
        // Convert to the output
        NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
        OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
        output = output_converter(result);
    }

    /// Called at the end of a row
    CUTLASS_DEVICE
    void end_row(int row_idx) {}

    /// Called after all accumulator elements have been visited
    CUTLASS_DEVICE
    void end_step(int step_idx) {
        iterator_D_.store(fragment_D_);
        ++iterator_D_;
    }

    /// Called after all steps have been completed
    CUTLASS_DEVICE
    void end_epilogue() {}

private:
    CUTLASS_DEVICE
    ComputeFragment per_token_channel_scale_accumulator_(
        ComputeFragment const &accum, ComputeFragment const &scale_col,
        AlphaScaleElementType const &scale_row) {
        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ComputeFragment::kElements; ++i) {
            result[i] = accum[i] * (scale_col[i] * scale_row);
        }
        return result;
    }

    CUTLASS_DEVICE
    ComputeFragment per_token_scale_accumulator_(
        ComputeFragment const &accum, AlphaScaleElementType const &scale_col,
        AlphaScaleElementType const &scale_row) {
        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ComputeFragment::kElements; ++i) {
            result[i] = accum[i] * (scale_col * scale_row);
        }
        return result;
    }

    CUTLASS_DEVICE
    ComputeFragment bias_accumulator_(ComputeFragment const &accum, ComputeFragment const &bias) {
        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < OutputVector::kElements; ++i) {
            result[i] = accum[i] + bias[i];
        }
        return result;
    }
};

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
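In formula form, visit() together with the two scale helpers and bias_accumulator_ computes, for output element (i, j),

$$D_{ij} = \mathrm{convert}_{\mathrm{out}}\!\left( \alpha^{\mathrm{row}}_{i} \, \alpha^{\mathrm{col}}_{j} \, \mathrm{Acc}_{ij} + C_{j} \right)$$

where the bias term $C_j$ is applied only when with_bias_ is set, $\alpha^{\mathrm{row}}$ collapses to a single scalar when per-token quantization is off, and $\alpha^{\mathrm{col}}$ collapses to a scalar when per-channel quantization is off.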
src/infiniop/ops/scaled_mm/nvidia/gemm_universal_base_compat.h
new file mode 100644

/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h

#pragma once

#include <cutlass/cutlass.h>
#include <cutlass/device_kernel.h>
#include <cutlass/trace.h>

////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

/*
  This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088).
  It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
  and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.

  Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
  that feature at the moment.
*/

template <typename GemmKernel_>
class GemmUniversalBaseCompat {
public:
    using GemmKernel = GemmKernel_;
    using ThreadblockShape = typename GemmKernel::Mma::Shape;

    using ElementA = typename GemmKernel::ElementA;
    using LayoutA = typename GemmKernel::LayoutA;
    using TensorRefA = TensorRef<ElementA const, LayoutA>;
    static ComplexTransform const kTransformA = GemmKernel::kTransformA;

    using ElementB = typename GemmKernel::ElementB;
    using LayoutB = typename GemmKernel::LayoutB;
    using TensorRefB = TensorRef<ElementB const, LayoutB>;
    static ComplexTransform const kTransformB = GemmKernel::kTransformB;

    using ElementC = typename GemmKernel::ElementC;
    using LayoutC = typename GemmKernel::LayoutC;
    using TensorRefC = TensorRef<ElementC const, LayoutC>;
    using TensorRefD = TensorRef<ElementC, LayoutC>;

    using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;

    using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
    using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
    using Operator = typename GemmKernel::Operator;

    /// Argument structure
    using Arguments = typename GemmKernel::Arguments;

protected:
    /// Kernel parameters object
    typename GemmKernel::Params params_;

protected:
    /// Private helper to obtain the grid dimensions with fix-up for split-K
    static void get_grid_shape_(gemm::GemmCoord &grid_tiled_shape, int &gemm_k_size, Arguments const &args) {
        // Determine grid shape
        ThreadblockSwizzle threadblock_swizzle;

        grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
            args.problem_size,
            {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
            args.batch_count);

        gemm_k_size = args.problem_size.k();

        if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
            int const kAlignK = const_max(
                const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

            gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

            if (gemm_k_size) {
                grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
            }
        }
    }

public:
    /// Constructs the GEMM.
    GemmUniversalBaseCompat() {}

    /// Determines whether the GEMM can execute the given problem.
    static Status can_implement(Arguments const &args) {
        // Determine grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        ThreadblockSwizzle threadblock_swizzle;
        dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

        uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);

        if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) {
            return Status::kErrorInvalidProblem;
        }

        return GemmKernel::can_implement(args);
    }

    /// Gets the workspace size
    static size_t get_workspace_size(Arguments const &args) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");

        size_t workspace_bytes = 0;

        // Determine grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
            // Split-K parallel always requires a temporary workspace
            workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
        } else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {
            // Serial split-K only requires a temporary workspace if the number of partitions along the
            // GEMM K dimension is greater than one.
            workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
        }

        CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);

        workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);

        return workspace_bytes;
    }

    /// Computes the grid shape
    static dim3 get_grid_shape(Arguments const &args) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");

        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
        dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

        CUTLASS_TRACE_HOST("  grid_tiled_shape: " << grid_tiled_shape << "\n"
                                                  << "  result = {" << result << "}");

        return result;
    }

    /// Computes the maximum number of active blocks per multiprocessor
    static int maximum_active_blocks(int smem_capacity = -1) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");

        int max_active_blocks = -1;
        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        CUTLASS_TRACE_HOST("  smem_size: " << smem_size << " bytes");

        if (smem_size <= (48 << 10)) {
            cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);

            if (result == cudaSuccess) {
                CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
                return max_active_blocks;
            }
        } else {
            // Query assuming zero shared memory then compute occupancy limit based on SMEM
            cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);

            if (result != cudaSuccess) {
                CUTLASS_TRACE_HOST("  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
                                   << cudaGetErrorString(result));
                return -1;
            }

            if (smem_capacity < 0) {
                int device_idx = 0;
                result = cudaGetDevice(&device_idx);

                if (result != cudaSuccess) {
                    return -1;
                }

                cudaDeviceProp properties;
                result = cudaGetDeviceProperties(&properties, device_idx);

                if (result != cudaSuccess) {
                    return -1;
                }

                smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
            }

            int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);

            CUTLASS_TRACE_HOST("  occupancy: " << occupancy);

            return occupancy;
        }

        CUTLASS_TRACE_HOST("  returning internal error");

        return -1;
    }

    /// Initializes GEMM state from arguments.
    Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
                           << workspace << ", stream: " << (stream ? "non-null" : "null"));

        size_t workspace_bytes = get_workspace_size(args);

        CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);

        if (workspace_bytes) {
            if (!workspace) {
                CUTLASS_TRACE_HOST("  error: device workspace must not be null");
                return Status::kErrorWorkspaceNull;
            }

            if (args.mode == GemmUniversalMode::kGemm) {
                CUTLASS_TRACE_HOST("  clearing device workspace");
                cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);

                if (result != cudaSuccess) {
                    CUTLASS_TRACE_HOST("  cudaMemsetAsync() returned error " << cudaGetErrorString(result));
                    return Status::kErrorInternal;
                }
            }
        }

        // Get CUDA grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        // Initialize the Params structure
        params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int *>(workspace));

        // Specify shared memory capacity for kernel.
        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        if (smem_size >= (48 << 10)) {
            cudaError_t result = cudaFuncSetAttribute(
                Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

            if (result != cudaSuccess) {
                return Status::kErrorInternal;
            }
        }

        return Status::kSuccess;
    }

    /// Lightweight update given a subset of arguments
    Status update(Arguments const &args, void *workspace = nullptr) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);

        size_t workspace_bytes = get_workspace_size(args);

        if (workspace_bytes && !workspace) {
            return Status::kErrorWorkspaceNull;
        }

        params_.update(args, workspace);

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status run(cudaStream_t stream = nullptr) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");

        //
        // Configure grid and block dimensions
        //
        ThreadblockSwizzle threadblock_swizzle;

        dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
        dim3 block(GemmKernel::kThreadCount, 1, 1);

        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        //
        // Launch kernel
        //
        CUTLASS_TRACE_HOST("  grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");

        // Launch
        cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

        //
        // Query for errors
        //
        cudaError_t result = cudaGetLastError();

        if (result != cudaSuccess) {
            CUTLASS_TRACE_HOST("  grid launch failed with error " << cudaGetErrorString(result));
            return Status::kErrorInternal;
        }

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status operator()(cudaStream_t stream = nullptr) {
        return run(stream);
    }

    /// Runs the kernel using initialized state.
    Status operator()(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
        Status status = initialize(args, workspace, stream);

        if (status == Status::kSuccess) {
            status = run(stream);
        }

        return status;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
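The workspace policy in get_workspace_size() is worth calling out: parallel split-K stages one full partial output per K slice, while serial split-K only needs an int semaphore per output tile, and only when K is actually partitioned. A hedged restatement of that rule, not library code:

#include <cstddef>

// Restatement of the split-K workspace rule from get_workspace_size() above.
// Parameters mirror the quantities the class derives from its Arguments.
size_t splitk_workspace_bytes(bool parallel_splitk, size_t elem_c_bytes,
                              size_t batch_stride_D,
                              int tiles_m, int tiles_n, int tiles_k) {
    if (parallel_splitk) {
        // One partial output buffer per K slice.
        return elem_c_bytes * batch_stride_D * size_t(tiles_k);
    }
    // Serial split-K: one int semaphore per output tile, only if K is split.
    return tiles_k > 1 ? sizeof(int) * size_t(tiles_m) * size_t(tiles_n) : 0;
}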
src/infiniop/ops/scaled_mm/nvidia/gemm_with_epilogue_visitor.h
new file mode 100644

/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h

#pragma once

#include <cutlass/complex.h>
#include <cutlass/cutlass.h>
#include <cutlass/fast_math.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/trace.h>

/////////////////////////////////////////////////////////////////////////////////////////////////

namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename Mma_,                 ///! Threadblock-scoped matrix multiply-accumulate
          typename Epilogue_,            ///! Epilogue
          typename ThreadblockSwizzle_>  ///! Threadblock swizzling function
struct GemmWithEpilogueVisitor {
public:
    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using EpilogueVisitor = typename Epilogue::Visitor;
    using ThreadblockSwizzle = ThreadblockSwizzle_;

    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using TensorRefA = TensorRef<ElementA, LayoutA>;

    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
    using TensorRefB = TensorRef<ElementB, LayoutB>;

    using ElementCompute = typename EpilogueVisitor::ElementCompute;
    using LayoutAlphaCol = cutlass::layout::RowMajor;
    using LayoutAlphaRow = cutlass::layout::ColumnMajor;
    using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
    using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;

    using ElementC = typename EpilogueVisitor::ElementOutput;
    using LayoutC = typename Epilogue::Layout;
    using TensorRefC = TensorRef<ElementC, LayoutC>;

    static ComplexTransform const kTransformA = Mma::kTransformA;
    static ComplexTransform const kTransformB = Mma::kTransformB;
    using Operator = typename Mma::Operator;

    using OperatorClass = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag = typename Mma::ArchTag;
    // Define type so GemmUniversalBase doesn't complain
    using EpilogueOutputOp = typename Epilogue::Visitor::ElementwiseFunctor;

    static int const kStages = Mma::kStages;
    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

    /// Split-K preserves splits that are 128b aligned
    static int const kSplitKAlignment = const_max(128 / sizeof_bits<ElementA>::value,
                                                  128 / sizeof_bits<ElementB>::value);

    //
    // Structures
    //

    /// Argument structure
    struct Arguments {
        //
        // Data members
        //
        GemmUniversalMode mode;
        GemmCoord problem_size;
        int batch_count;

        TensorRefA ref_A;
        TensorRefB ref_B;
        TensorRefAlphaCol ref_alpha_col;
        TensorRefAlphaRow ref_alpha_row;
        TensorRefC ref_C;
        TensorRefC ref_D;

        int64_t batch_stride_A;
        int64_t batch_stride_B;
        int64_t batch_stride_D;

        typename EpilogueVisitor::Arguments epilogue_visitor;

        //
        // Methods
        //
        Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}

        /// constructs an arguments structure
        Arguments(GemmCoord problem_size_,
                  TensorRefA ref_A_,
                  TensorRefB ref_B_,
                  TensorRefAlphaCol ref_alpha_col_,
                  TensorRefAlphaRow ref_alpha_row_,
                  TensorRefC ref_C_,
                  TensorRefC ref_D_,
                  typename EpilogueVisitor::Arguments epilogue_visitor_)
            : mode(GemmUniversalMode::kGemm), problem_size(problem_size_), batch_count(1),
              ref_A(ref_A_), ref_B(ref_B_),
              ref_alpha_col(ref_alpha_col_), ref_alpha_row(ref_alpha_row_),
              ref_C(ref_C_), ref_D(ref_D_),
              batch_stride_A(0), batch_stride_B(0), batch_stride_D(0),
              epilogue_visitor(epilogue_visitor_) {}
    };

    //
    // Structure for precomputing values in host memory and passing to kernels
    //

    /// Parameters structure
    struct Params {
        cutlass::gemm::GemmCoord problem_size;
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int swizzle_log_tile;

        typename Mma::IteratorA::Params params_A;
        typename Mma::IteratorB::Params params_B;
        typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
        typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
        typename EpilogueVisitor::OutputTileIterator::Params params_C;
        typename EpilogueVisitor::OutputTileIterator::Params params_D;

        GemmUniversalMode mode;
        int batch_count;
        int gemm_k_size;

        void *ptr_A;
        void *ptr_B;
        typename EpilogueVisitor::ScaleTileIterator::Element *ptr_alpha_col;
        typename EpilogueVisitor::ScaleTileIterator::Element *ptr_alpha_row;
        ElementC *ptr_C;
        ElementC *ptr_D;

        int64_t batch_stride_A;
        int64_t batch_stride_B;

        typename EpilogueVisitor::Params epilogue_visitor;

        //
        // Methods
        //
        CUTLASS_HOST_DEVICE
        Params()
            : swizzle_log_tile(0), params_A(0), params_B(0), params_alpha_col(0),
              params_C(0), params_D(0), batch_count(0), gemm_k_size(0),
              mode(cutlass::gemm::GemmUniversalMode::kGemm),
              ptr_A(nullptr), ptr_B(nullptr),
              ptr_alpha_col(nullptr), ptr_alpha_row(nullptr),
              ptr_C(nullptr), ptr_D(nullptr),
              batch_stride_A(0), batch_stride_B(0) {}

        Params(Arguments const &args,
               cutlass::gemm::GemmCoord const &grid_tiled_shape_,
               int gemm_k_size_,
               int *workspace_)
            : problem_size(args.problem_size), swizzle_log_tile(0),
              params_A(args.ref_A.layout()),
              params_B(args.ref_B.layout()),
              params_alpha_col(args.ref_alpha_col.layout()),
              params_alpha_row(args.ref_alpha_col.layout()),
              params_C(args.ref_C.layout()),
              params_D(args.ref_D.layout()),
              mode(args.mode),
              batch_count(args.batch_count),
              gemm_k_size(args.problem_size.k()),
              ptr_A(args.ref_A.data()),
              ptr_B(args.ref_B.data()),
              ptr_alpha_col(args.ref_alpha_col.data()),
              ptr_alpha_row(args.ref_alpha_row.data()),
              ptr_C(args.ref_C.data()),
              ptr_D(args.ref_D.data()),
              batch_stride_A(args.batch_stride_A),
              batch_stride_B(args.batch_stride_B),
              epilogue_visitor(args.epilogue_visitor) {
            ThreadblockSwizzle threadblock_swizzle;

            grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
                args.problem_size,
                {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
                args.batch_count);

            if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
                int const kAlignK = const_max(
                    const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

                gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

                if (gemm_k_size) {
                    grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
                }
            }

            swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
        }
    };

    /// Shared memory storage structure
    union SharedStorage {
        typename Mma::SharedStorage main_loop;

        struct {
            typename Epilogue::SharedStorage epilogue;
            typename EpilogueVisitor::SharedStorage visitor;
        } epilogue;
    };

public:
    //
    // Methods
    //
    CUTLASS_DEVICE
    GemmWithEpilogueVisitor() {}

    /// Determines whether kernel satisfies alignment
    static Status can_implement(cutlass::gemm::GemmCoord const &problem_size) {
        CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");

        static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
        static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
        static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;

        bool isAMisaligned = false;
        bool isBMisaligned = false;
        bool isCMisaligned = false;

        if (platform::is_same<LayoutA, layout::RowMajor>::value) {
            isAMisaligned = problem_size.k() % kAlignmentA;
        } else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
            isAMisaligned = problem_size.m() % kAlignmentA;
        } else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
                   || platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
            isAMisaligned = problem_size.k() % kAlignmentA;
        }

        if (platform::is_same<LayoutB, layout::RowMajor>::value) {
            isBMisaligned = problem_size.n() % kAlignmentB;
        } else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
            isBMisaligned = problem_size.k() % kAlignmentB;
        } else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
                   || platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
            isBMisaligned = problem_size.k() % kAlignmentB;
        }

        if (platform::is_same<LayoutC, layout::RowMajor>::value) {
            isCMisaligned = problem_size.n() % kAlignmentC;
        } else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
            isCMisaligned = problem_size.m() % kAlignmentC;
        } else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
                   || platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
            isCMisaligned = problem_size.n() % kAlignmentC;
        }

        if (isAMisaligned) {
            CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for A operand");
            return Status::kErrorMisalignedOperand;
        }

        if (isBMisaligned) {
            CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for B operand");
            return Status::kErrorMisalignedOperand;
        }

        if (isCMisaligned) {
            CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for C operand");
            return Status::kErrorMisalignedOperand;
        }

        CUTLASS_TRACE_HOST("  returning kSuccess");

        return Status::kSuccess;
    }

    static Status can_implement(Arguments const &args) {
        return can_implement(args.problem_size);
    }

    static size_t get_extra_workspace_size(Arguments const &args,
                                           cutlass::gemm::GemmCoord const &grid_tiled_shape) {
        return 0;
    }

#define SPLIT_K_ENABLED 1

    /// Executes one GEMM
    CUTLASS_DEVICE
    void run_kernel_(Params const &params, SharedStorage &shared_storage) {
        // Compute threadblock location
        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // Early exit if CTA is out of range
        if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
            || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
            return;
        }

        int offset_k = 0;
        int problem_size_k = params.problem_size.k();

        ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
        ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);

#if SPLIT_K_ENABLED
        //
        // Fetch pointers based on mode.
        //
        if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) {
            if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
                problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
            }

            offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
        } else if (params.mode == GemmUniversalMode::kBatched) {
            ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
            ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
        } else if (params.mode == GemmUniversalMode::kArray) {
            ptr_A = static_cast<ElementA *const *>(params.ptr_A)[threadblock_tile_offset.k()];
            ptr_B = static_cast<ElementB *const *>(params.ptr_B)[threadblock_tile_offset.k()];
        }
#endif

        // Compute initial location in logical coordinates
        cutlass::MatrixCoord tb_offset_A{
            threadblock_tile_offset.m() * Mma::Shape::kM,
            offset_k,
        };

        cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};

        // Compute position within threadblock
        int thread_idx = threadIdx.x;

        // Construct iterators to A and B operands
        typename Mma::IteratorA iterator_A(params.params_A, ptr_A,
                                           {params.problem_size.m(), problem_size_k},
                                           thread_idx, tb_offset_A);

        typename Mma::IteratorB iterator_B(params.params_B, ptr_B,
                                           {problem_size_k, params.problem_size.n()},
                                           thread_idx, tb_offset_B);

        // Broadcast the warp_id computed by lane 0 to ensure dependent code
        // is compiled as warp-uniform.
        int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
        int lane_idx = threadIdx.x % 32;

        //
        // Main loop
        //

        // Construct thread-scoped matrix multiply
        Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

        typename Mma::FragmentC accumulators;

        accumulators.clear();

        // Compute threadblock-scoped matrix multiply-add
        int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

        // Compute threadblock-scoped matrix multiply-add
        mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

        //
        // Masked tile iterators constructed from members
        //
        threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // assume identity swizzle
        MatrixCoord threadblock_offset(threadblock_tile_offset.m() * Mma::Shape::kM,
                                       threadblock_tile_offset.n() * Mma::Shape::kN);

        int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

        //
        // Construct the epilogue visitor
        //
        bool with_bias = true;
        if (params.ptr_C == nullptr) {
            with_bias = false;
        }

        EpilogueVisitor epilogue_visitor(
            params.epilogue_visitor, shared_storage.epilogue.visitor,
            params.problem_size.mn(), thread_idx, warp_idx, lane_idx,
            params.params_alpha_col, params.params_C, params.params_D,
            with_bias, true, true,
            params.ptr_alpha_row, params.ptr_alpha_col, params.ptr_C, params.ptr_D,
            threadblock_offset, blockIdx.y * params.problem_size.m());

        if (params.mode == GemmUniversalMode::kGemm) {
            // Indicate which position in a serial reduction the output operator is currently updating
            epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
        } else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
            epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
        }

        // Construct the epilogue
        Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);

        // Execute the epilogue operator to update the destination tensor.
        epilogue(epilogue_visitor, accumulators);
    }

    template <typename CompilationArch>
    CUTLASS_DEVICE void run_kernel(Params const &params, SharedStorage &shared_storage) {
        if constexpr (platform::is_same<ArchTag, CompilationArch>::value) {
            run_kernel_(params, shared_storage);
        } else {
            CUTLASS_NOT_IMPLEMENTED();
        }
    }

    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const &params, SharedStorage &shared_storage) {
        run_kernel<ArchTag>(params, shared_storage);
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
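The serial split-K bookkeeping in run_kernel_() can be read as each threadblock owning a half-open K range. A hedged restatement under the same variable meanings (not library code):

// Sketch: how a threadblock's K range falls out of run_kernel_() above.
// Tile k_idx covers [k_idx * gemm_k_size, min((k_idx + 1) * gemm_k_size, total_k)).
struct KRange {
    int begin;
    int end;
};

KRange splitk_range(int k_idx, int tiles_k, int gemm_k_size, int total_k) {
    int begin = k_idx * gemm_k_size; // offset_k in the kernel
    int end = (k_idx + 1 < tiles_k)  // problem_size_k in the kernel
                  ? (k_idx + 1) * gemm_k_size
                  : total_k;
    return {begin, end};
}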
src/infiniop/ops/scaled_mm/nvidia/int8_gemm_kernel.cuh
new file mode 100644

(This diff is collapsed in the page view; the file's 693 added lines are not shown.)
src/infiniop/ops/scaled_mm/nvidia/int8_gemm_nvidia.cu
new file mode 100644

#include "../../../devices/nvidia/nvidia_handle.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#include "int8_gemm_kernel.cuh"
#include "int8_gemm_nvidia.cuh"

namespace op::i8gemm::nvidia {

struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}

inline int getSMVersion() {
    int device{-1};
    CHECK_CUDA(cudaGetDevice(&device));
    int sm_major = 0;
    int sm_minor = 0;
    CHECK_CUDA(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
    CHECK_CUDA(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
    return sm_major * 10 + sm_minor;
}

infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t bias_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t a_scale_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t b_scale_desc) {
    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
    auto dtype = out_desc->dtype();
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);
    auto result = I8GemmInfo::create(out_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
    CHECK_RESULT(result);
    *desc_ptr = new Descriptor(
        new Opaque{handle->internal()},
        result.take(),
        0,
        dtype,
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}

infiniStatus_t Descriptor::calculate(
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *bias,
    const void *a,
    const void *a_scale,
    const void *b,
    const void *b_scale,
    void *stream) const {
    auto sm_version = getSMVersion();
    if (sm_version >= 75 && sm_version < 80) {
        CHECK_DTYPE(this->_out_dtype, INFINI_DTYPE_F16);
        sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
            out, a, b, a_scale, b_scale, bias,
            _info.m, _info.n, _info.k,
            _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
    } else if (sm_version >= 80 && sm_version < 90) {
        // sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
        if (sm_version == 86 || sm_version == 89) {
            if (this->_out_dtype == INFINI_DTYPE_BF16) {
                sm89_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
                    out, a, b, a_scale, b_scale, bias,
                    _info.m, _info.n, _info.k,
                    _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
            } else {
                sm89_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
                    out, a, b, a_scale, b_scale, bias,
                    _info.m, _info.n, _info.k,
                    _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
            }
        } else {
            if (this->_out_dtype == INFINI_DTYPE_BF16) {
                sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
                    out, a, b, a_scale, b_scale, bias,
                    _info.m, _info.n, _info.k,
                    _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
            } else {
                sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
                    out, a, b, a_scale, b_scale, bias,
                    _info.m, _info.n, _info.k,
                    _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
            }
        }
    } else if (sm_version == 90) {
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
        // cutlass 3.x
        if (this->_out_dtype == INFINI_DTYPE_BF16) {
            sm90_dispatch_shape<cutlass::bfloat16_t>(
                out, a, b, a_scale, b_scale, bias,
                _info.m, _info.n, _info.k,
                _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
        } else {
            sm90_dispatch_shape<cutlass::half_t>(
                out, a, b, a_scale, b_scale, bias,
                _info.m, _info.n, _info.k,
                _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
        }
#else
        // fallback to cutlass 2.x
        if (this->_out_dtype == INFINI_DTYPE_BF16) {
            sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
                out, a, b, a_scale, b_scale, bias,
                _info.m, _info.n, _info.k,
                _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
        } else {
            sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
                out, a, b, a_scale, b_scale, bias,
                _info.m, _info.n, _info.k,
                _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
        }
#endif
    } else {
        return INFINI_STATUS_NOT_IMPLEMENTED;
    }
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::i8gemm::nvidia
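The architecture dispatch above reduces to a small decision table. A hedged restatement for reference (not code from the commit; cuda12 stands in for the compile-time CUDA_VERSION >= 12000 branch):

#include <string>

// Mirrors Descriptor::calculate()'s dispatch on compute capability.
std::string int8_gemm_path(int sm_version, bool cuda12) {
    if (sm_version >= 75 && sm_version < 80) {
        return "sm75 kernel (F16 output only)";
    }
    if (sm_version >= 80 && sm_version < 90) {
        // sm86/sm89 have ~100KB of shared memory per SM vs. 160KB on sm80,
        // so they get their own tiling path.
        return (sm_version == 86 || sm_version == 89) ? "sm89 kernel" : "sm80 kernel";
    }
    if (sm_version == 90) {
        return cuda12 ? "sm90 kernel (CUTLASS 3.x)" : "sm80 kernel (CUTLASS 2.x fallback)";
    }
    return "not implemented";
}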
src/infiniop/ops/scaled_mm/nvidia/int8_gemm_nvidia.cuh
new file mode 100644

#ifndef __INT8_GEMM_NVIDIA_API_H__
#define __INT8_GEMM_NVIDIA_API_H__

#include "../int8_gemm.h"

DESCRIPTOR(nvidia)

#endif // __INT8_GEMM_NVIDIA_API_H__
src/infiniop/ops/scaled_mm/operator.cc
new file mode 100644

#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/int8_gemm.h"

#if defined(ENABLE_NVIDIA_API)
#include "nvidia/int8_gemm_nvidia.cuh"
#endif

__C infiniStatus_t infiniopCreateI8GemmDescriptor(
    infiniopHandle_t handle,
    infiniopI8GemmDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t bias_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t a_scale_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t b_scale_desc) {

#define CREATE(CASE, NAMESPACE)                                               \
    case CASE:                                                                \
        return op::i8gemm::NAMESPACE::Descriptor::create(                     \
            handle,                                                           \
            reinterpret_cast<op::i8gemm::NAMESPACE::Descriptor **>(desc_ptr), \
            out_desc,                                                         \
            bias_desc,                                                        \
            a_desc,                                                           \
            a_scale_desc,                                                     \
            b_desc,                                                           \
            b_scale_desc);

    switch (handle->device) {
#ifdef ENABLE_NVIDIA_API
        CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CREATE
}

__C infiniStatus_t infiniopGetI8GemmWorkspaceSize(
    infiniopI8GemmDescriptor_t desc,
    size_t *size) {
    switch (desc->device_type) {

#define GET(CASE, NAMESPACE)                                                                      \
    case CASE:                                                                                    \
        *size = reinterpret_cast<op::i8gemm::NAMESPACE::Descriptor *>(desc)->minWorkspaceSize(); \
        return INFINI_STATUS_SUCCESS;

#ifdef ENABLE_NVIDIA_API
        GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef GET
}

__C infiniStatus_t infiniopI8Gemm(
    infiniopI8GemmDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *bias,
    const void *a,
    const void *a_scale,
    const void *b,
    const void *b_scale,
    void *stream) {

#define CALCULATE(CASE, NAMESPACE)                                                      \
    case CASE:                                                                          \
        return reinterpret_cast<op::i8gemm::NAMESPACE::Descriptor *>(desc)->calculate( \
            workspace, workspace_size, out, bias, a, a_scale, b, b_scale, stream);

    switch (desc->device_type) {
#ifdef ENABLE_NVIDIA_API
        CALCULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef CALCULATE
}

__C infiniStatus_t infiniopDestroyI8GemmDescriptor(
    infiniopI8GemmDescriptor_t desc) {
    switch (desc->device_type) {

#define DESTROY(CASE, NAMESPACE)                                                \
    case CASE:                                                                  \
        delete reinterpret_cast<op::i8gemm::NAMESPACE::Descriptor *>(desc);     \
        return INFINI_STATUS_SUCCESS;

#ifdef ENABLE_NVIDIA_API
        DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }

#undef DESTROY
}
test/infiniop/libinfiniop/op_register.py

@@ -938,3 +938,42 @@ def tanh_(lib):
    lib.infiniopDestroyTanhDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]


@OpRegister.operator
def scaled_mm_int8_(lib):
    lib.infiniopCreateI8GemmDescriptor.restype = c_int32
    lib.infiniopCreateI8GemmDescriptor.argtypes = [
        infiniopHandle_t,
        POINTER(infiniopOperatorDescriptor_t),
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
        infiniopTensorDescriptor_t,
    ]

    lib.infiniopGetI8GemmWorkspaceSize.restype = c_int32
    lib.infiniopGetI8GemmWorkspaceSize.argtypes = [
        infiniopOperatorDescriptor_t,
        POINTER(c_size_t),
    ]

    lib.infiniopI8Gemm.restype = c_int32
    lib.infiniopI8Gemm.argtypes = [
        infiniopOperatorDescriptor_t,
        c_void_p,
        c_size_t,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
        c_void_p,
    ]

    lib.infiniopDestroyI8GemmDescriptor.restype = c_int32
    lib.infiniopDestroyI8GemmDescriptor.argtypes = [
        infiniopOperatorDescriptor_t,
    ]
test/infiniop/libinfiniop/utils.py

@@ -336,7 +336,7 @@ def rearrange_tensor(tensor, new_strides):
         torch.float32,
         torch.float64,
     ]:
-        new_tensor.view(-1).index_add_(0, new_positions, tensor.view(-1))
+        new_tensor.view(-1).index_add_(0, new_positions, tensor.contiguous().view(-1))
     elif tensor.dtype in [torch.uint16, torch.uint32, torch.uint64]:
         new_tensor_int64 = new_tensor.to(dtype=torch.int64)
         tensor_int64 = tensor.to(dtype=torch.int64)
test/infiniop/scaled_mm_int8.py
new file mode 100644

import torch
import ctypes
from ctypes import c_uint64
from libinfiniop import (
    LIBINFINIOP,
    TestTensor,
    get_test_devices,
    check_error,
    test_operator,
    get_args,
    debug,
    get_tolerance,
    profile_operation,
    TestWorkspace,
    InfiniDtype,
    InfiniDtypeNames,
    InfiniDeviceNames,
    infiniopOperatorDescriptor_t,
)
from enum import Enum, auto

# ==============================================================================
#  Configuration (Internal Use Only)
# ==============================================================================
# These are not meant to be imported from other modules
_TEST_CASES_ = [
    # x_shape, w_shape, symmetric, y_shape, alpha, beta
    # ((8, 8), (8, 8), False, (8, 8), 1.0, 0.0),
    ((128, 512), (512, 1024), True, (128, 1024), 1.0, 0.0),
    # ((128, 128), (128, 128), False, (128, 128), 2.0, 1.0),
    ((256, 1024), (1024, 2048), True, (256, 2048), 1.0, 1.0),
    ((1024, 2048), (2048, 1024), True, (1024, 1024), 1.0, 0.0),
]


class Inplace(Enum):
    OUT_OF_PLACE = auto()
    INPLACE = auto()


# Inplace options applied for each test case in _TEST_CASES_
_INPLACE = [
    # Inplace.OUT_OF_PLACE,
    Inplace.INPLACE,
]

_TEST_CASES = [
    test_case + (inplace_item,)
    for test_case in _TEST_CASES_
    for inplace_item in _INPLACE
]

# Data types used for testing
_TENSOR_DTYPES = [InfiniDtype.BF16, InfiniDtype.F16]

# Tolerance map for different data types
_TOLERANCE_MAP = {
    InfiniDtype.F16: {"atol": 3e-1, "rtol": 1e-2},
    InfiniDtype.BF16: {"atol": 3e-1, "rtol": 1e-2},
}

DEBUG = False
PROFILE = False
NUM_PRERUN = 10
NUM_ITERATIONS = 1000


def to_int8(tensor: torch.Tensor) -> torch.Tensor:
    return torch.round(tensor.clamp(min=-128, max=127)).to(dtype=torch.int8)


def torch_scaled_mm(a, b, scale_a, scale_b, out_dtype, bias):
    o = torch.matmul(a.to(torch.float32), b.to(torch.float32))
    if bias is not None:
        o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1) + bias
    else:
        o = o.to(torch.float32) * scale_a.view(-1, 1) * scale_b.view(1, -1)
    return o.to(out_dtype)
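torch_scaled_mm is the float32 reference the kernel output is compared against. Written out, with $s_a$ the per-row (per-token) scale and $s_b$ the per-column (per-channel) scale:

$$O = \mathrm{cast}_{\texttt{out\_dtype}}\bigl( \operatorname{diag}(s_a)\,(A_{f32} B_{f32})\,\operatorname{diag}(s_b) + \mathrm{bias} \bigr)$$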
def test(
    handle,
    device,
    x_shape,
    w_shape,
    symmetric,
    y_shape,
    alpha,
    beta,
    inplace=Inplace.OUT_OF_PLACE,
    dtype=InfiniDtype.BF16,
    sync=None,
):
    print(
        f"Testing Linear on {InfiniDeviceNames[device]} with x_shape: {x_shape},"
        f" w_shape: {w_shape}, symmetric: {symmetric}, alpha: {alpha}, beta: {beta},"
        f" inplace: {inplace} dtype: {InfiniDtypeNames[dtype]}"
    )
    M, K = x_shape
    N = w_shape[1]
    torch_out_dtype = torch.float16 if dtype == InfiniDtype.F16 else torch.bfloat16
    x_packed = to_int8(torch.randn((M, K), device="cuda") * 5)
    weights = to_int8(torch.randn((N, K), device="cuda").t() * 5)
    x_scale = torch.randn((M,), device="cuda", dtype=torch.float32)
    weights_scale = torch.randn((N,), device="cuda", dtype=torch.float32)
    bias = torch.randn((N,), device="cuda", dtype=torch_out_dtype) * 10
    ans = torch_scaled_mm(
        x_packed, weights, x_scale, weights_scale, torch_out_dtype, bias=bias
    )
    x_packed = TestTensor(
        (M, K), x_packed.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=x_packed
    )
    x_scale = TestTensor(
        (M,), x_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=x_scale
    )
    weights = TestTensor(
        (K, N), weights.stride(), InfiniDtype.I8, device, mode="manual", set_tensor=weights
    )
    weights_scale = TestTensor(
        (N,), weights_scale.stride(), InfiniDtype.F32, device, mode="manual", set_tensor=weights_scale
    )
    y = TestTensor(y_shape, None, dtype, device)
    bias = TestTensor((N,), bias.stride(), dtype, device, mode="manual", set_tensor=bias)

    descriptor = infiniopOperatorDescriptor_t()
    check_error(
        LIBINFINIOP.infiniopCreateI8GemmDescriptor(
            handle,
            ctypes.byref(descriptor),
            y.descriptor,
            bias.descriptor,
            x_packed.descriptor,
            x_scale.descriptor,
            weights.descriptor,
            weights_scale.descriptor,
        )
    )

    workspace_size = c_uint64(0)
    check_error(
        LIBINFINIOP.infiniopGetI8GemmWorkspaceSize(
            descriptor, ctypes.byref(workspace_size)
        )
    )
    workspace = TestWorkspace(workspace_size.value, x_packed.device)

    def lib_linear():
        check_error(
            LIBINFINIOP.infiniopI8Gemm(
                descriptor,
                workspace.data(),
                workspace_size.value,
                y.data(),
                bias.data(),
                x_packed.data(),
                x_scale.data(),
                weights.data(),
                weights_scale.data(),
                None,
            )
        )

    lib_linear()

    if sync is not None:
        sync()

    atol, rtol = get_tolerance(_TOLERANCE_MAP, dtype)
    if DEBUG:
        debug(y.actual_tensor(), ans, atol=atol, rtol=rtol)
    assert torch.allclose(y.actual_tensor(), ans, atol=atol, rtol=rtol)

    # Profiling workflow
    if PROFILE:
        # fmt: off
        profile_operation(
            "PyTorch",
            lambda: torch_scaled_mm(
                x_packed.torch_tensor(), weights.torch_tensor(),
                x_scale.torch_tensor(), weights_scale.torch_tensor(),
                torch_out_dtype, bias.torch_tensor(),
            ),
            device, NUM_PRERUN, NUM_ITERATIONS,
        )
        profile_operation("    lib", lambda: lib_linear(), device, NUM_PRERUN, NUM_ITERATIONS)
        # fmt: on
    check_error(LIBINFINIOP.infiniopDestroyI8GemmDescriptor(descriptor))


if __name__ == "__main__":
    args = get_args()

    # Configure testing options
    DEBUG = args.debug
    PROFILE = args.profile
    NUM_PRERUN = args.num_prerun
    NUM_ITERATIONS = args.num_iterations

    for device in get_test_devices(args):
        test_operator(device, test, _TEST_CASES, _TENSOR_DTYPES)

    print("\033[92mTest passed!\033[0m")
xmake/nvidia.lua

@@ -4,11 +4,9 @@ if CUDNN_ROOT ~= nil then
end

local CUTLASS_ROOT = os.getenv("CUTLASS_ROOT") or os.getenv("CUTLASS_HOME") or os.getenv("CUTLASS_PATH")
local CUTE_ROOT = os.getenv("CUTE_ROOT") or os.getenv("CUTE_HOME") or os.getenv("CUTE_PATH")
if CUTLASS_ROOT ~= nil then
    add_includedirs(CUTLASS_ROOT)
    add_includedirs(CUTE_ROOT)
end

target("infiniop-nvidia")

@@ -22,7 +20,6 @@ target("infiniop-nvidia")
    if has_config("cudnn") then
        add_links("cudnn")
    end
    add_cugencodes("native")
    on_load(function (target)
        import("lib.detect.find_tool")

@@ -36,11 +33,6 @@ target("infiniop-nvidia")
            target:add("linkdirs", path.directory(path.directory(nvcc_path)) .. "/lib64/stubs")
            target:add("links", "cuda")
            local cuda_arch = get_config("cuda_arch")
            if cuda_arch ~= nil then
                target:add("cu-cxxflags", "-arch=", cuda_arch)
            end
        end
    end)

@@ -65,6 +57,17 @@ target("infiniop-nvidia")
    add_cuflags("-Xcompiler=-Wno-error=deprecated-declarations")
    local arch_opt = get_config("cuda_arch")
    if arch_opt and type(arch_opt) == "string" then
        for _, arch in ipairs(arch_opt:split(",")) do
            arch = arch:trim()
            local compute = arch:gsub("sm_", "compute_")
            add_cuflags("-gencode=arch=" .. compute .. ",code=" .. arch)
        end
    else
        add_cugencodes("native")
    end
    set_languages("cxx17")
    add_files("../src/infiniop/devices/nvidia/*.cu", "../src/infiniop/ops/*/nvidia/*.cu")