jerrrrry / infinicore · Commits · 8d09630a

Unverified commit 8d09630a, authored Feb 11, 2026 by gongchensu; committed by GitHub on Feb 11, 2026.
Merge branch 'demo131' into Issue/862
Parents: ab52dead, 012df56c
Changes: 387
Showing 20 changed files with 2778 additions and 5 deletions (+2778, -5)
src/infiniop/ops/scaled_mm/cuda/per_channel_dequant_int8.cuh  +80 -0
src/infiniop/ops/scaled_mm/info.h  +121 -0
src/infiniop/ops/scaled_mm/int8_gemm.h  +51 -0
src/infiniop/ops/scaled_mm/nvidia/epilogue_per_row_per_col_scale.h  +308 -0
src/infiniop/ops/scaled_mm/nvidia/gemm_universal_base_compat.h  +354 -0
src/infiniop/ops/scaled_mm/nvidia/gemm_with_epilogue_visitor.h  +487 -0
src/infiniop/ops/scaled_mm/nvidia/int8_gemm_kernel.cuh  +681 -0
src/infiniop/ops/scaled_mm/nvidia/int8_gemm_nvidia.cu  +206 -0
src/infiniop/ops/scaled_mm/nvidia/int8_gemm_nvidia.cuh  +7 -0
src/infiniop/ops/scaled_mm/operator.cc  +102 -0
src/infiniop/ops/sigmoid/operator.cc  +14 -1
src/infiniop/ops/silu/operator.cc  +14 -1
src/infiniop/ops/silu_and_mul/info.h  +54 -0
src/infiniop/ops/silu_and_mul/moore/silu_and_mul_moore.h  +8 -0
src/infiniop/ops/silu_and_mul/moore/silu_and_mul_moore.mu  +123 -0
src/infiniop/ops/silu_and_mul/operator.cc  +79 -0
src/infiniop/ops/silu_and_mul/silu_and_mul.h  +46 -0
src/infiniop/ops/softmax/operator.cc  +13 -1
src/infiniop/ops/softplus/operator.cc  +17 -1
src/infiniop/ops/sub/operator.cc  +13 -1
src/infiniop/ops/scaled_mm/cuda/per_channel_dequant_int8.cuh (new file, 0 → 100644)
#ifndef __PER_CHANNEL_DEQUANT_INT8_KERNEL_CUH__
#define __PER_CHANNEL_DEQUANT_INT8_KERNEL_CUH__
/**
* @brief Symmetric dequantization kernel for post-processing quantized matrix multiplication
*
* This kernel performs symmetric dequantization on the packed integer output from
* a quantized matrix multiplication. It converts integer results back to floating-point
* values by applying per-row scaling factors for the input and per-column scaling
* factors for the weights, then adds a bias term.
*
* The dequantization formula is:
* y = x_scale * w_scale * y_packed + bias
*
* @tparam Tdata Output data type (typically bfloat16 or half)
*
* @param[out] y Output tensor after dequantization
* Shape: [M, N], Data type: Tdata
*
* @param[in] y_packed Packed integer output from quantized matmul
* Shape: [M, N], Data type: int32_t
* Contains integer results of: x_packed[i,:] * w_packed[:,j]
*
* @param[in] bias Bias tensor to add after dequantization
* Shape: [N], Data type: Tdata
* Broadcasted across all rows
*
* @param[in] x_packed Packed quantized input tensor (not directly used here)
* Shape: [M, K], Data type: int8_t
* Included for context of the computation pipeline
*
* @param[in] x_scale Per-row scaling factors for the input
* Shape: [M], Data type: float
* One scale value per input row
*
* @param[in] w_packed Packed quantized weight tensor (not directly used here)
* Shape: [K, N], Data type: int8_t
* Included for context of the computation pipeline
*
* @param[in] w_scale Per-column scaling factors for the weights
* Shape: [N], Data type: float
* One scale value per output column
*
* @param[in] M Batch size / number of input rows
*
* @param[in] K Inner dimension of matrix multiplication
*
* @param[in] N Output dimension / number of output columns
*
* @note This kernel assumes symmetric quantization (zero-point = 0)
* @note Each thread processes one element of the output matrix
* @note Grid and block dimensions should be configured to cover [M, N] output space
*/
template <typename Tdata>
__device__ void postSymKernel(
    Tdata *y,
    int32_t *y_packed,
    const Tdata *bias,
    const int8_t *x_packed,
    const float *x_scale,
    const int8_t *w_packed,
    const float *w_scale,
    int M,
    int K,
    int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) {
        return;
    }
    int idx = row * N + col;
    float output1 = x_scale[row] * w_scale[col] * ((float)y_packed[idx]);
    float output = output1 + (float)bias[col];
    y[idx] = static_cast<Tdata>(output);
}
// y = x_scale * w_scale * y_packed
template <typename Tdata>
__device__ void postSymKernel(
    Tdata *y,
    int32_t *y_packed,
    const int8_t *x_packed,
    const float *x_scale,
    const int8_t *w_packed,
    const float *w_scale,
    int M,
    int K,
    int N) {
    int row = blockIdx.y * blockDim.y + threadIdx.y;
    int col = blockIdx.x * blockDim.x + threadIdx.x;
    if (row >= M || col >= N) {
        return;
    }
    int idx = row * N + col;
    float output = x_scale[row] * w_scale[col] * ((float)y_packed[idx]);
    y[idx] = static_cast<Tdata>(output);
}
#endif // __PER_CHANNEL_DEQUANT_INT8_KERNEL_CUH__
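The notes above say each thread handles one output element and that the grid must cover the [M, N] output space. As a minimal host-side sketch of such a launch, assuming a 16x16 thread block; the __global__ wrapper and launcher names below are illustrative and not part of this commit:

// Hypothetical wrapper and launch helper for the dequant kernel above (sketch only).
template <typename Tdata>
__global__ void postSymKernelWrapper(Tdata *y, int32_t *y_packed, const Tdata *bias,
                                     const int8_t *x_packed, const float *x_scale,
                                     const int8_t *w_packed, const float *w_scale,
                                     int M, int K, int N) {
    postSymKernel<Tdata>(y, y_packed, bias, x_packed, x_scale, w_packed, w_scale, M, K, N);
}

template <typename Tdata>
void launchPostSym(Tdata *y, int32_t *y_packed, const Tdata *bias,
                   const int8_t *x_packed, const float *x_scale,
                   const int8_t *w_packed, const float *w_scale,
                   int M, int K, int N, cudaStream_t stream) {
    dim3 block(16, 16);                     // one thread per output element
    dim3 grid((N + block.x - 1) / block.x,  // columns along x
              (M + block.y - 1) / block.y); // rows along y
    postSymKernelWrapper<Tdata><<<grid, block, 0, stream>>>(
        y, y_packed, bias, x_packed, x_scale, w_packed, w_scale, M, K, N);
}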
src/infiniop/ops/scaled_mm/info.h (new file, 0 → 100644)
#ifndef __I8GEMM_INFO_H__
#define __I8GEMM_INFO_H__
#include "../../../utils.h"
#include "../../operator.h"
#include "../../tensor.h"
#include <algorithm>
namespace op::i8gemm {

struct BlasMatrix {
    int ndim;
    int batch;
    int stride;
    int rows;
    int cols;
    int row_stride;
    int col_stride;

    static utils::Result<BlasMatrix> create(infiniopTensorDescriptor_t layout) {
        BlasMatrix ans;
        if (layout->ndim() == 2) {
            ans.ndim = 2;
            ans.batch = 1;
            ans.stride = 0;
            ans.rows = layout->dim(0);
            ans.cols = layout->dim(1);
            ans.row_stride = layout->stride(0);
            ans.col_stride = layout->stride(1);
        } else if (layout->ndim() == 3) {
            ans.ndim = 3;
            ans.batch = layout->dim(0);
            ans.stride = ans.batch == 1 ? 0 : layout->stride(0);
            ans.rows = layout->dim(1);
            ans.cols = layout->dim(2);
            ans.row_stride = layout->stride(1);
            ans.col_stride = layout->stride(2);
        } else {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        if (ans.row_stride != 1 && ans.col_stride != 1) {
            return INFINI_STATUS_BAD_TENSOR_STRIDES;
        }
        return utils::Result<BlasMatrix>(ans);
    }

    bool match_batch(int _batch) const {
        return batch == _batch || batch == 1;
    }

    void transpose() {
        std::swap(rows, cols);
        std::swap(row_stride, col_stride);
    }

    int ld() const {
        return row_stride == 1 ? col_stride : row_stride;
    }
};

enum class MatrixLayout : char {
    COL_MAJOR,
    ROW_MAJOR,
};

class I8GemmInfo {
    I8GemmInfo() = default;

public:
    BlasMatrix a_matrix;
    BlasMatrix b_matrix;
    BlasMatrix out_matrix;
    int m, n, k, batch;

    static utils::Result<I8GemmInfo> create(
        infiniopTensorDescriptor_t out_desc,
        infiniopTensorDescriptor_t a_desc,
        infiniopTensorDescriptor_t b_desc,
        MatrixLayout layout) {
        auto a_matrix = BlasMatrix::create(a_desc);
        CHECK_RESULT(a_matrix);
        auto b_matrix = BlasMatrix::create(b_desc);
        CHECK_RESULT(b_matrix);
        auto out_matrix = BlasMatrix::create(out_desc);
        CHECK_RESULT(out_matrix);

        if (out_matrix->rows != a_matrix->rows
            || out_matrix->cols != b_matrix->cols
            || a_matrix->cols != b_matrix->rows) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        auto batch = out_matrix->batch;
        if (!a_matrix->match_batch(batch) || !b_matrix->match_batch(batch)) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }

        auto m = out_matrix->rows;
        auto n = out_matrix->cols;
        auto k = a_matrix->cols;

        return utils::Result<I8GemmInfo>(I8GemmInfo{
            a_matrix.take(),
            b_matrix.take(),
            out_matrix.take(),
            m,
            n,
            k,
            batch});
    }
};

} // namespace op::i8gemm

#endif // __I8GEMM_INFO_H__
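For orientation, a sketch of how a backend would typically consume this helper when creating a descriptor; the surrounding descriptor variables are assumed to already exist in that backend's create() path, and the snippet is not taken from this diff:

// Hypothetical caller-side usage of I8GemmInfo::create and BlasMatrix::ld().
auto info = op::i8gemm::I8GemmInfo::create(
    out_desc, a_desc, b_desc, op::i8gemm::MatrixLayout::ROW_MAJOR);
CHECK_RESULT(info);                    // propagate shape/stride errors
int m = info->m, n = info->n, k = info->k;
int lda = info->a_matrix.ld();         // leading dimension of A (row- or column-major)
int ldb = info->b_matrix.ld();         // leading dimension of B
int ldd = info->out_matrix.ld();       // leading dimension of the output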
src/infiniop/ops/scaled_mm/int8_gemm.h (new file, 0 → 100644)
#ifndef __I8GEMM_H__
#define __I8GEMM_H__
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::i8gemm::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
size_t _workspace_size; \
I8GemmInfo _info; \
infiniDtype_t _out_dtype; \
\
Descriptor(Opaque *opaque, I8GemmInfo info, \
size_t workspace_size, \
infiniDtype_t out_dtype, \
infiniDevice_t device_type, int device_id) \
: InfiniopDescriptor{device_type, device_id}, _out_dtype(out_dtype), \
_opaque(opaque), _info(info), _workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t minWorkspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, Descriptor **desc_ptr, \
infiniopTensorDescriptor_t out_desc, \
infiniopTensorDescriptor_t bias_desc, \
infiniopTensorDescriptor_t a_desc, \
infiniopTensorDescriptor_t a_scale_desc, \
infiniopTensorDescriptor_t b_desc, \
infiniopTensorDescriptor_t b_scale_desc); \
template <unsigned int BLOCK_SIZE, typename Tdata> \
infiniStatus_t launchKernel(const I8GemmInfo &info, Tdata *y, \
const Tdata *bias, const int8_t *x_packed, \
const float *x_scale, const int8_t *w_packed, \
const float *w_scale, void *stream, void *workspace) const; \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *out, const void *bias, const void *a, \
const void *a_scale, const void *b, \
const void *b_scale, void *stream) const; \
}; \
}
#endif // __I8GEMM_H__
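As a usage note, a backend header normally instantiates this macro once per device namespace. A sketch of the expected pattern (the nvidia namespace matches the files added in this commit, but the snippet itself is an assumption about intended usage, not code from the diff):

// Hypothetical backend header (e.g. nvidia/int8_gemm_nvidia.cuh): one macro
// invocation declares op::i8gemm::nvidia::Descriptor with create()/calculate().
#include "../int8_gemm.h"

DESCRIPTOR(nvidia)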
src/infiniop/ops/scaled_mm/nvidia/epilogue_per_row_per_col_scale.h (new file, 0 → 100644)
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h
#pragma once
#include <cutlass/arch/memory.h>
#include <cutlass/numeric_conversion.h>
namespace cutlass {
namespace epilogue {
namespace threadblock {

template <
    typename ThreadblockShape_,
    int ThreadCount,
    typename ScaleTileIterator_,
    typename OutputTileIterator_,
    typename ElementAccumulator_,
    typename ElementCompute_,
    typename ElementwiseFunctor_,
    bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol {
public:
    using ThreadblockShape = ThreadblockShape_;
    static int const kThreadCount = ThreadCount;

    using ScaleTileIterator = ScaleTileIterator_;
    using OutputTileIterator = OutputTileIterator_;
    using ElementwiseFunctor = ElementwiseFunctor_;

    static int const kIterations = OutputTileIterator::kIterations;
    static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

    using ElementOutput = typename OutputTileIterator::Element;
    using LayoutOutput = cutlass::layout::RowMajor;
    using ElementAccumulator = ElementAccumulator_;

    using AlphaScaleElementType = typename ScaleTileIterator::Element;

    using ElementCompute = ElementCompute_;
    using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
    using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
    using OutputVector = Array<ElementOutput, kElementsPerAccess>;

    static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
    static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);

    /// Argument structure
    struct Arguments {
        typename ElementwiseFunctor::Params elementwise;
        int64_t batch_stride_alpha;
        int64_t batch_stride_C;
        int64_t batch_stride_D;

        //
        // Methods
        //
        Arguments()
            : batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}

        Arguments(typename ElementwiseFunctor::Params elementwise_)
            : elementwise(elementwise_), batch_stride_alpha(0), batch_stride_C(0), batch_stride_D(0) {}

        Arguments(
            typename ElementwiseFunctor::Params elementwise_,
            int64_t batch_stride_alpha_,
            int64_t batch_stride_C_,
            int64_t batch_stride_D_)
            : elementwise(elementwise_),
              batch_stride_alpha(batch_stride_alpha_),
              batch_stride_C(batch_stride_C_),
              batch_stride_D(batch_stride_D_) {}
    };

    struct Params {
        typename ElementwiseFunctor::Params elementwise;
        int64_t batch_stride_alpha;
        int64_t batch_stride_C;
        int64_t batch_stride_D;

        //
        // Methods
        //
        CUTLASS_HOST_DEVICE
        Params() {}

        CUTLASS_HOST_DEVICE
        Params(Arguments const &args)
            : elementwise(args.elementwise),
              batch_stride_alpha(args.batch_stride_alpha),
              batch_stride_C(args.batch_stride_C),
              batch_stride_D(args.batch_stride_D) {}
    };

    /// Shared storage
    struct SharedStorage {};

private:
    Params const &params_;
    SharedStorage &shared_storage_;
    MatrixCoord extent_;
    MatrixCoord extent_real_;
    ElementwiseFunctor elementwise_;

    bool const with_bias_;
    bool const per_token_quant_;
    bool const per_channel_quant_;

    AlphaScaleElementType *ptr_alpha_row_;
    AlphaScaleElementType *ptr_alpha_col_;
    ScaleTileIterator iterator_alpha_col_;
    OutputTileIterator iterator_C_;
    OutputTileIterator iterator_D_;

    AlphaScaleElementType element_alpha_row_ = 1.0f;
    AlphaScaleElementType element_alpha_col_ = 1.0f;
    typename ScaleTileIterator::Fragment fragment_alpha_col_;
    typename OutputTileIterator::Fragment fragment_C_;
    typename OutputTileIterator::Fragment fragment_D_;

    ElementAccumulator beta_;

    int column_offset_;

    MatrixCoord thread_offset_;

public:
    CUTLASS_DEVICE
    EpilogueVisitorPerRowPerCol(
        Params const &params,
        SharedStorage &shared_storage,
        cutlass::MatrixCoord const &problem_size,
        int thread_idx,
        int warp_idx,
        int lane_idx,
        typename ScaleTileIterator::Params params_alpha_col,
        typename OutputTileIterator::Params params_C,
        typename OutputTileIterator::Params params_D,
        bool with_bias,
        bool per_token_quant,
        bool per_channel_quant,
        AlphaScaleElementType *ptr_alpha_row,
        AlphaScaleElementType *ptr_alpha_col,
        typename OutputTileIterator::Element *ptr_C,
        typename OutputTileIterator::Element *ptr_D,
        cutlass::MatrixCoord const &threadblock_offset = cutlass::MatrixCoord(0, 0),
        int column_offset = 0,
        cutlass::MatrixCoord const &problem_size_real = cutlass::MatrixCoord(0, 0))
        : params_(params),
          shared_storage_(shared_storage),
          extent_(problem_size),
          elementwise_(params.elementwise),
          with_bias_(with_bias),
          per_token_quant_(per_token_quant),
          per_channel_quant_(per_channel_quant),
          ptr_alpha_row_(ptr_alpha_row),
          ptr_alpha_col_(ptr_alpha_col),
          iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset),
          iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset),
          iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset),
          extent_real_(problem_size_real) {
        if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr)) {
            element_alpha_col_ = *ptr_alpha_col_;
        }
        if (!per_token_quant_ && (ptr_alpha_row_ != nullptr)) {
            element_alpha_row_ = *ptr_alpha_row_;
        }
    }

    /// Helper to indicate split-K behavior
    CUTLASS_DEVICE
    void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
                         int split_k_slices) {
        ///< Total number of split-K slices
    }

    /// Called to set the batch index
    CUTLASS_DEVICE
    void set_batch_index(int batch_idx) {
        iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
        iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
        iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
    }

    /// Called at the start of the epilogue just before iterating over accumulator slices
    CUTLASS_DEVICE
    void begin_epilogue() {
        if (per_channel_quant_) {
            iterator_alpha_col_.load(fragment_alpha_col_);
        }
        if (with_bias_) {
            iterator_C_.load(fragment_C_);
        }
    }

    /// Called at the start of one step before starting accumulator exchange
    CUTLASS_DEVICE
    void begin_step(int step_idx) {
        fragment_D_.clear();
    }

    /// Called at the start of a row
    CUTLASS_DEVICE
    void begin_row(int row_idx) {
        // load alpha_row in begin_step only when per token(row) scaling is used
        if (per_token_quant_) {
            int thread_offset_row = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();

            arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
                element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
        }
    }

    /// Called after accumulators have been exchanged for each accumulator vector
    CUTLASS_DEVICE
    void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const &accum) {
        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;

        ComputeFragment result = source_converter(accum);
        if (per_channel_quant_) {
            ComputeFragment alpha_col = reinterpret_cast<ComputeFragment *>(&fragment_alpha_col_)[column_idx];
            result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
        } else {
            result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
        }

        if (with_bias_) {
            NumericArrayConverter<ElementCompute, ElementOutput, kElementsPerAccess> bias_converter;
            OutputVector bias = reinterpret_cast<OutputVector *>(&fragment_C_)[column_idx];
            result = bias_accumulator_(result, bias_converter(bias));
        }

        // Convert to the output
        NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
        OutputVector &output = reinterpret_cast<OutputVector *>(&fragment_D_)[frag_idx];
        output = output_converter(result);
    }

    /// Called at the end of a row
    CUTLASS_DEVICE
    void end_row(int row_idx) {}

    /// Called after all accumulator elements have been visited
    CUTLASS_DEVICE
    void end_step(int step_idx) {
        iterator_D_.store(fragment_D_);
        ++iterator_D_;
    }

    /// Called after all steps have been completed
    CUTLASS_DEVICE
    void end_epilogue() {}

private:
    CUTLASS_DEVICE
    ComputeFragment per_token_channel_scale_accumulator_(
        ComputeFragment const &accum, ComputeFragment const &scale_col, AlphaScaleElementType const &scale_row) {
        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ComputeFragment::kElements; ++i) {
            result[i] = accum[i] * (scale_col[i] * scale_row);
        }
        return result;
    }

    CUTLASS_DEVICE
    ComputeFragment per_token_scale_accumulator_(
        ComputeFragment const &accum, AlphaScaleElementType const &scale_col, AlphaScaleElementType const &scale_row) {
        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ComputeFragment::kElements; ++i) {
            result[i] = accum[i] * (scale_col * scale_row);
        }
        return result;
    }

    CUTLASS_DEVICE
    ComputeFragment bias_accumulator_(ComputeFragment const &accum, ComputeFragment const &bias) {
        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < OutputVector::kElements; ++i) {
            result[i] = accum[i] + bias[i];
        }
        return result;
    }
};

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
src/infiniop/ops/scaled_mm/nvidia/gemm_universal_base_compat.h (new file, 0 → 100644)
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h
#pragma once
#include <cutlass/cutlass.h>
#include <cutlass/device_kernel.h>
#include <cutlass/trace.h>
////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace device {

/////////////////////////////////////////////////////////////////////////////////////////////////

/*
  This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
  It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
  and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.

  Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
  that feature at the moment.
*/

template <typename GemmKernel_>
class GemmUniversalBaseCompat {
public:
    using GemmKernel = GemmKernel_;
    using ThreadblockShape = typename GemmKernel::Mma::Shape;

    using ElementA = typename GemmKernel::ElementA;
    using LayoutA = typename GemmKernel::LayoutA;
    using TensorRefA = TensorRef<ElementA const, LayoutA>;
    static ComplexTransform const kTransformA = GemmKernel::kTransformA;

    using ElementB = typename GemmKernel::ElementB;
    using LayoutB = typename GemmKernel::LayoutB;
    using TensorRefB = TensorRef<ElementB const, LayoutB>;
    static ComplexTransform const kTransformB = GemmKernel::kTransformB;

    using ElementC = typename GemmKernel::ElementC;
    using LayoutC = typename GemmKernel::LayoutC;
    using TensorRefC = TensorRef<ElementC const, LayoutC>;
    using TensorRefD = TensorRef<ElementC, LayoutC>;

    using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;

    using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
    using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
    using Operator = typename GemmKernel::Operator;

    /// Argument structure
    using Arguments = typename GemmKernel::Arguments;

protected:
    /// Kernel parameters object
    typename GemmKernel::Params params_;

protected:
    /// Private helper to obtain the grid dimensions with fix-up for split-K
    static void get_grid_shape_(gemm::GemmCoord &grid_tiled_shape, int &gemm_k_size, Arguments const &args) {
        // Determine grid shape
        ThreadblockSwizzle threadblock_swizzle;

        grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
            args.problem_size,
            {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
            args.batch_count);

        gemm_k_size = args.problem_size.k();

        if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
            int const kAlignK = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

            gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

            if (gemm_k_size) {
                grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
            }
        }
    }

public:
    /// Constructs the GEMM.
    GemmUniversalBaseCompat() {}

    /// Determines whether the GEMM can execute the given problem.
    static Status can_implement(Arguments const &args) {
        // Determine grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        ThreadblockSwizzle threadblock_swizzle;
        dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

        uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);

        if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax)) {
            return Status::kErrorInvalidProblem;
        }

        return GemmKernel::can_implement(args);
    }

    /// Gets the workspace size
    static size_t get_workspace_size(Arguments const &args) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");

        size_t workspace_bytes = 0;

        // Determine grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        if (args.mode == GemmUniversalMode::kGemmSplitKParallel) {
            // Split-K parallel always requires a temporary workspace
            workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
        } else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1) {
            // Serial split-K only requires a temporary workspace if the number of partitions along the
            // GEMM K dimension is greater than one.
            workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
        }

        CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);

        workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);

        return workspace_bytes;
    }

    /// Computes the grid shape
    static dim3 get_grid_shape(Arguments const &args) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");

        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
        dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

        CUTLASS_TRACE_HOST("  grid_tiled_shape: " << grid_tiled_shape << "\n"
                                                  << "  result = {" << result << "}");

        return result;
    }

    /// Computes the maximum number of active blocks per multiprocessor
    static int maximum_active_blocks(int smem_capacity = -1) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");

        int max_active_blocks = -1;
        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        CUTLASS_TRACE_HOST("  smem_size: " << smem_size << " bytes");

        if (smem_size <= (48 << 10)) {
            cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);

            if (result == cudaSuccess) {
                CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
                return max_active_blocks;
            }
        } else {
            // Query assuming zero shared memory then compute occupancy limit based on SMEM
            cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);

            if (result != cudaSuccess) {
                CUTLASS_TRACE_HOST("  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error "
                                   << cudaGetErrorString(result));
                return -1;
            }

            if (smem_capacity < 0) {
                int device_idx = 0;
                result = cudaGetDevice(&device_idx);

                if (result != cudaSuccess) {
                    return -1;
                }

                cudaDeviceProp properties;
                result = cudaGetDeviceProperties(&properties, device_idx);

                if (result != cudaSuccess) {
                    return -1;
                }

                smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
            }

            int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);

            CUTLASS_TRACE_HOST("  occupancy: " << occupancy);

            return occupancy;
        }

        CUTLASS_TRACE_HOST("  returning internal error");

        return -1;
    }

    /// Initializes GEMM state from arguments.
    Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
                           << workspace << ", stream: " << (stream ? "non-null" : "null"));

        size_t workspace_bytes = get_workspace_size(args);

        CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);

        if (workspace_bytes) {
            if (!workspace) {
                CUTLASS_TRACE_HOST("  error: device workspace must not be null");
                return Status::kErrorWorkspaceNull;
            }

            if (args.mode == GemmUniversalMode::kGemm) {
                CUTLASS_TRACE_HOST("  clearing device workspace");
                cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);

                if (result != cudaSuccess) {
                    CUTLASS_TRACE_HOST("  cudaMemsetAsync() returned error " << cudaGetErrorString(result));
                    return Status::kErrorInternal;
                }
            }
        }

        // Get CUDA grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        // Initialize the Params structure
        params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int *>(workspace));

        // Specify shared memory capacity for kernel.
        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        if (smem_size >= (48 << 10)) {
            cudaError_t result = cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

            if (result != cudaSuccess) {
                return Status::kErrorInternal;
            }
        }

        return Status::kSuccess;
    }

    /// Lightweight update given a subset of arguments
    Status update(Arguments const &args, void *workspace = nullptr) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);

        size_t workspace_bytes = get_workspace_size(args);

        if (workspace_bytes && !workspace) {
            return Status::kErrorWorkspaceNull;
        }

        params_.update(args, workspace);

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status run(cudaStream_t stream = nullptr) {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");

        //
        // Configure grid and block dimensions
        //
        ThreadblockSwizzle threadblock_swizzle;

        dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
        dim3 block(GemmKernel::kThreadCount, 1, 1);

        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        //
        // Launch kernel
        //
        CUTLASS_TRACE_HOST("  grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");

        // Launch
        cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

        //
        // Query for errors
        //
        cudaError_t result = cudaGetLastError();

        if (result != cudaSuccess) {
            CUTLASS_TRACE_HOST("  grid launch failed with error " << cudaGetErrorString(result));
            return Status::kErrorInternal;
        }

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status operator()(cudaStream_t stream = nullptr) {
        return run(stream);
    }

    /// Runs the kernel using initialized state.
    Status operator()(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {
        Status status = initialize(args, workspace, stream);

        if (status == Status::kSuccess) {
            status = run(stream);
        }

        return status;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
src/infiniop/ops/scaled_mm/nvidia/gemm_with_epilogue_visitor.h (new file, 0 → 100644)
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Adapted from
// https://github.com/NVIDIA/TensorRT-LLM/blob/be1788106245496872d18e702978e59b6bfd50e0/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/gemm_with_epilogue_visitor.h
#pragma once
#include <cutlass/complex.h>
#include <cutlass/cutlass.h>
#include <cutlass/fast_math.h>
#include <cutlass/matrix_coord.h>
#include <cutlass/trace.h>
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass {
namespace gemm {
namespace kernel {

/////////////////////////////////////////////////////////////////////////////////////////////////

template <
    typename Mma_,                ///! Threadblock-scoped matrix multiply-accumulate
    typename Epilogue_,           ///! Epilogue
    typename ThreadblockSwizzle_  ///! Threadblock swizzling function
    >
struct GemmWithEpilogueVisitor {
public:
    using Mma = Mma_;
    using Epilogue = Epilogue_;
    using EpilogueVisitor = typename Epilogue::Visitor;
    using ThreadblockSwizzle = ThreadblockSwizzle_;

    using ElementA = typename Mma::IteratorA::Element;
    using LayoutA = typename Mma::IteratorA::Layout;
    using TensorRefA = TensorRef<ElementA, LayoutA>;

    using ElementB = typename Mma::IteratorB::Element;
    using LayoutB = typename Mma::IteratorB::Layout;
    using TensorRefB = TensorRef<ElementB, LayoutB>;

    using ElementCompute = typename EpilogueVisitor::ElementCompute;
    using LayoutAlphaCol = cutlass::layout::RowMajor;
    using LayoutAlphaRow = cutlass::layout::ColumnMajor;
    using TensorRefAlphaCol = TensorRef<ElementCompute, LayoutAlphaCol>;
    using TensorRefAlphaRow = TensorRef<ElementCompute, LayoutAlphaRow>;

    using ElementC = typename EpilogueVisitor::ElementOutput;
    using LayoutC = typename Epilogue::Layout;
    using TensorRefC = TensorRef<ElementC, LayoutC>;

    static ComplexTransform const kTransformA = Mma::kTransformA;
    static ComplexTransform const kTransformB = Mma::kTransformB;
    using Operator = typename Mma::Operator;

    using OperatorClass = typename Mma::Operator::OperatorClass;
    using ThreadblockShape = typename Mma::Shape;
    using WarpShape = typename Mma::Operator::Shape;
    using InstructionShape = typename Mma::Policy::Operator::InstructionShape;
    using ArchTag = typename Mma::ArchTag;
    using EpilogueOutputOp = typename Epilogue::Visitor::ElementwiseFunctor;
    // Define type so GemmUniversalBase doesn't complain
    static int const kStages = Mma::kStages;
    static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
    static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
    static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess;

    /// Warp count (concept: GemmShape)
    using WarpCount = typename Mma::WarpCount;
    static int const kThreadCount = 32 * WarpCount::kCount;

    /// Split-K preserves splits that are 128b aligned
    static int const kSplitKAlignment = const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value);

    //
    // Structures
    //

    /// Argument structure
    struct Arguments {
        //
        // Data members
        //
        GemmUniversalMode mode;
        GemmCoord problem_size;
        int batch_count;

        TensorRefA ref_A;
        TensorRefB ref_B;
        TensorRefAlphaCol ref_alpha_col;
        TensorRefAlphaRow ref_alpha_row;
        TensorRefC ref_C;
        TensorRefC ref_D;

        int64_t batch_stride_A;
        int64_t batch_stride_B;
        int64_t batch_stride_D;

        typename EpilogueVisitor::Arguments epilogue_visitor;

        //
        // Methods
        //
        Arguments() : mode(GemmUniversalMode::kGemm), batch_count(1) {}

        /// constructs an arguments structure
        Arguments(
            GemmCoord problem_size_,
            TensorRefA ref_A_,
            TensorRefB ref_B_,
            TensorRefAlphaCol ref_alpha_col_,
            TensorRefAlphaRow ref_alpha_row_,
            TensorRefC ref_C_,
            TensorRefC ref_D_,
            typename EpilogueVisitor::Arguments epilogue_visitor_)
            : mode(GemmUniversalMode::kGemm),
              problem_size(problem_size_),
              batch_count(1),
              ref_A(ref_A_),
              ref_B(ref_B_),
              ref_alpha_col(ref_alpha_col_),
              ref_alpha_row(ref_alpha_row_),
              ref_C(ref_C_),
              ref_D(ref_D_),
              batch_stride_A(0),
              batch_stride_B(0),
              batch_stride_D(0),
              epilogue_visitor(epilogue_visitor_) {}
    };

    //
    // Structure for precomputing values in host memory and passing to kernels
    //

    /// Parameters structure
    struct Params {
        cutlass::gemm::GemmCoord problem_size;
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int swizzle_log_tile;

        typename Mma::IteratorA::Params params_A;
        typename Mma::IteratorB::Params params_B;
        typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_col;
        typename EpilogueVisitor::ScaleTileIterator::Params params_alpha_row;
        typename EpilogueVisitor::OutputTileIterator::Params params_C;
        typename EpilogueVisitor::OutputTileIterator::Params params_D;

        GemmUniversalMode mode;
        int batch_count;
        int gemm_k_size;

        void *ptr_A;
        void *ptr_B;
        typename EpilogueVisitor::ScaleTileIterator::Element *ptr_alpha_col;
        typename EpilogueVisitor::ScaleTileIterator::Element *ptr_alpha_row;
        ElementC *ptr_C;
        ElementC *ptr_D;

        int64_t batch_stride_A;
        int64_t batch_stride_B;

        typename EpilogueVisitor::Params epilogue_visitor;

        //
        // Methods
        //
        CUTLASS_HOST_DEVICE
        Params()
            : swizzle_log_tile(0),
              params_A(0),
              params_B(0),
              params_alpha_col(0),
              params_C(0),
              params_D(0),
              batch_count(0),
              gemm_k_size(0),
              mode(cutlass::gemm::GemmUniversalMode::kGemm),
              ptr_A(nullptr),
              ptr_B(nullptr),
              ptr_alpha_col(nullptr),
              ptr_alpha_row(nullptr),
              ptr_C(nullptr),
              ptr_D(nullptr),
              batch_stride_A(0),
              batch_stride_B(0) {}

        Params(
            Arguments const &args,
            cutlass::gemm::GemmCoord const &grid_tiled_shape_,
            int gemm_k_size_,
            int *workspace_)
            : problem_size(args.problem_size),
              swizzle_log_tile(0),
              params_A(args.ref_A.layout()),
              params_B(args.ref_B.layout()),
              params_alpha_col(args.ref_alpha_col.layout()),
              params_alpha_row(args.ref_alpha_col.layout()),
              params_C(args.ref_C.layout()),
              params_D(args.ref_D.layout()),
              mode(args.mode),
              batch_count(args.batch_count),
              gemm_k_size(args.problem_size.k()),
              ptr_A(args.ref_A.data()),
              ptr_B(args.ref_B.data()),
              ptr_alpha_col(args.ref_alpha_col.data()),
              ptr_alpha_row(args.ref_alpha_row.data()),
              ptr_C(args.ref_C.data()),
              ptr_D(args.ref_D.data()),
              batch_stride_A(args.batch_stride_A),
              batch_stride_B(args.batch_stride_B),
              epilogue_visitor(args.epilogue_visitor) {
            ThreadblockSwizzle threadblock_swizzle;

            grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
                args.problem_size,
                {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK},
                args.batch_count);

            if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) {
                int const kAlignK = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

                gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

                if (gemm_k_size) {
                    grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
                }
            }

            swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape);
        }
    };

    /// Shared memory storage structure
    union SharedStorage {
        typename Mma::SharedStorage main_loop;

        struct {
            typename Epilogue::SharedStorage epilogue;
            typename EpilogueVisitor::SharedStorage visitor;
        } epilogue;
    };

public:
    //
    // Methods
    //

    CUTLASS_DEVICE
    GemmWithEpilogueVisitor() {}

    /// Determines whether kernel satisfies alignment
    static Status can_implement(cutlass::gemm::GemmCoord const &problem_size) {
        CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()");

        static int const kAlignmentA = Mma::IteratorA::AccessType::kElements;
        static int const kAlignmentB = Mma::IteratorB::AccessType::kElements;
        static int const kAlignmentC = EpilogueVisitor::OutputTileIterator::kElementsPerAccess;

        bool isAMisaligned = false;
        bool isBMisaligned = false;
        bool isCMisaligned = false;

        if (platform::is_same<LayoutA, layout::RowMajor>::value) {
            isAMisaligned = problem_size.k() % kAlignmentA;
        } else if (platform::is_same<LayoutA, layout::ColumnMajor>::value) {
            isAMisaligned = problem_size.m() % kAlignmentA;
        } else if (platform::is_same<LayoutA, layout::ColumnMajorInterleaved<32>>::value
                   || platform::is_same<LayoutA, layout::ColumnMajorInterleaved<64>>::value) {
            isAMisaligned = problem_size.k() % kAlignmentA;
        }

        if (platform::is_same<LayoutB, layout::RowMajor>::value) {
            isBMisaligned = problem_size.n() % kAlignmentB;
        } else if (platform::is_same<LayoutB, layout::ColumnMajor>::value) {
            isBMisaligned = problem_size.k() % kAlignmentB;
        } else if (platform::is_same<LayoutB, layout::RowMajorInterleaved<32>>::value
                   || platform::is_same<LayoutB, layout::RowMajorInterleaved<64>>::value) {
            isBMisaligned = problem_size.k() % kAlignmentB;
        }

        if (platform::is_same<LayoutC, layout::RowMajor>::value) {
            isCMisaligned = problem_size.n() % kAlignmentC;
        } else if (platform::is_same<LayoutC, layout::ColumnMajor>::value) {
            isCMisaligned = problem_size.m() % kAlignmentC;
        } else if (platform::is_same<LayoutC, layout::ColumnMajorInterleaved<32>>::value
                   || platform::is_same<LayoutC, layout::ColumnMajorInterleaved<64>>::value) {
            isCMisaligned = problem_size.n() % kAlignmentC;
        }

        if (isAMisaligned) {
            CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for A operand");
            return Status::kErrorMisalignedOperand;
        }

        if (isBMisaligned) {
            CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for B operand");
            return Status::kErrorMisalignedOperand;
        }

        if (isCMisaligned) {
            CUTLASS_TRACE_HOST("  returning kErrorMisalignedOperand for C operand");
            return Status::kErrorMisalignedOperand;
        }

        CUTLASS_TRACE_HOST("  returning kSuccess");

        return Status::kSuccess;
    }

    static Status can_implement(Arguments const &args) {
        return can_implement(args.problem_size);
    }

    static size_t get_extra_workspace_size(Arguments const &args, cutlass::gemm::GemmCoord const &grid_tiled_shape) {
        return 0;
    }

#define SPLIT_K_ENABLED 1

    /// Executes one GEMM
    CUTLASS_DEVICE
    void run_kernel_(Params const &params, SharedStorage &shared_storage) {
        // Compute threadblock location
        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // Early exit if CTA is out of range
        if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m()
            || params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {
            return;
        }

        int offset_k = 0;
        int problem_size_k = params.problem_size.k();

        ElementA *ptr_A = static_cast<ElementA *>(params.ptr_A);
        ElementB *ptr_B = static_cast<ElementB *>(params.ptr_B);

#if SPLIT_K_ENABLED
        //
        // Fetch pointers based on mode.
        //
        if (params.mode == GemmUniversalMode::kGemm || params.mode == GemmUniversalMode::kGemmSplitKParallel) {
            if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) {
                problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size;
            }

            offset_k = threadblock_tile_offset.k() * params.gemm_k_size;
        } else if (params.mode == GemmUniversalMode::kBatched) {
            ptr_A += threadblock_tile_offset.k() * params.batch_stride_A;
            ptr_B += threadblock_tile_offset.k() * params.batch_stride_B;
        } else if (params.mode == GemmUniversalMode::kArray) {
            ptr_A = static_cast<ElementA *const *>(params.ptr_A)[threadblock_tile_offset.k()];
            ptr_B = static_cast<ElementB *const *>(params.ptr_B)[threadblock_tile_offset.k()];
        }
#endif

        // Compute initial location in logical coordinates
        cutlass::MatrixCoord tb_offset_A{
            threadblock_tile_offset.m() * Mma::Shape::kM,
            offset_k,
        };

        cutlass::MatrixCoord tb_offset_B{offset_k, threadblock_tile_offset.n() * Mma::Shape::kN};

        // Compute position within threadblock
        int thread_idx = threadIdx.x;

        // Construct iterators to A and B operands
        typename Mma::IteratorA iterator_A(params.params_A, ptr_A, {params.problem_size.m(), problem_size_k}, thread_idx, tb_offset_A);

        typename Mma::IteratorB iterator_B(params.params_B, ptr_B, {problem_size_k, params.problem_size.n()}, thread_idx, tb_offset_B);

        // Broadcast the warp_id computed by lane 0 to ensure dependent code
        // is compiled as warp-uniform.
        int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);
        int lane_idx = threadIdx.x % 32;

        //
        // Main loop
        //

        // Construct thread-scoped matrix multiply
        Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);

        typename Mma::FragmentC accumulators;

        accumulators.clear();

        // Compute threadblock-scoped matrix multiply-add
        int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK;

        // Compute threadblock-scoped matrix multiply-add
        mma(gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators);

        //
        // Masked tile iterators constructed from members
        //

        threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile);

        // assume identity swizzle
        MatrixCoord threadblock_offset(
            threadblock_tile_offset.m() * Mma::Shape::kM,
            threadblock_tile_offset.n() * Mma::Shape::kN);

        int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m();

        //
        // Construct the epilogue visitor
        //
        bool with_bias = true;
        if (params.ptr_C == nullptr) {
            with_bias = false;
        }

        EpilogueVisitor epilogue_visitor(
            params.epilogue_visitor,
            shared_storage.epilogue.visitor,
            params.problem_size.mn(),
            thread_idx,
            warp_idx,
            lane_idx,
            params.params_alpha_col,
            params.params_C,
            params.params_D,
            with_bias,
            true,
            true,
            params.ptr_alpha_row,
            params.ptr_alpha_col,
            params.ptr_C,
            params.ptr_D,
            threadblock_offset,
            blockIdx.y * params.problem_size.m());

        if (params.mode == GemmUniversalMode::kGemm) {
            // Indicate which position in a serial reduction the output operator is currently updating
            epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k());
        } else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) {
            epilogue_visitor.set_batch_index(threadblock_tile_offset.k());
        }

        // Construct the epilogue
        Epilogue epilogue(shared_storage.epilogue.epilogue, thread_idx, warp_idx, lane_idx);

        // Execute the epilogue operator to update the destination tensor.
        epilogue(epilogue_visitor, accumulators);
    }

    template <typename CompilationArch>
    CUTLASS_DEVICE void run_kernel(Params const &params, SharedStorage &shared_storage) {
        if constexpr (platform::is_same<ArchTag, CompilationArch>::value) {
            run_kernel_(params, shared_storage);
        } else {
            CUTLASS_NOT_IMPLEMENTED();
        }
    }

    /// Executes one GEMM
    CUTLASS_DEVICE
    void operator()(Params const &params, SharedStorage &shared_storage) {
        run_kernel<ArchTag>(params, shared_storage);
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace kernel
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
src/infiniop/ops/scaled_mm/nvidia/int8_gemm_kernel.cuh (new file, 0 → 100644)
/* Copyright 2025 SGLang Team. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cutlass/cutlass.h>
#include <cutlass/epilogue/thread/linear_combination.h>
#include <cutlass/epilogue/threadblock/epilogue_with_visitor.h>
#include <cutlass/gemm/device/gemm.h>
#include <cutlass/gemm/device/gemm_universal_adapter.h>
#include <cutlass/numeric_types.h>
#include <cute/atom/mma_atom.hpp>
#include <cute/tensor.hpp>
#include <cutlass/epilogue/collective/collective_builder.hpp>
#include <cutlass/gemm/collective/collective_builder.hpp>
#include <cutlass/gemm/kernel/gemm_universal.hpp>
#include <cutlass/util/packed_stride.hpp>
#include "../../../devices/nvidia/nvidia_common.cuh"
#include "epilogue_per_row_per_col_scale.h"
#include "gemm_universal_base_compat.h"
#include "gemm_with_epilogue_visitor.h"
using
namespace
cute
;
inline
infiniStatus_t
check_cutlass_status
(
cutlass
::
Status
status
)
{
if
(
status
!=
cutlass
::
Status
::
kSuccess
)
{
return
INFINI_STATUS_INTERNAL_ERROR
;
}
return
INFINI_STATUS_SUCCESS
;
}
template
<
typename
ElementOutput
,
typename
ArchTag
,
typename
ThreadblockShape
,
typename
WarpShape
,
typename
InstructionShape
,
int
NumStages
>
void
cutlass_int8_scaled_mm
(
void
*
out
,
const
void
*
a
,
const
void
*
b
,
const
void
*
a_scale
,
const
void
*
b_scale
,
const
void
*
bias
,
int
m
,
int
n
,
int
k
,
int
lda
,
int
ldb
,
int
ldd
,
void
*
stream
)
{
using
ElementAccumulator
=
int32_t
;
using
ElementCompute
=
float
;
using
ElementInputA
=
int8_t
;
using
ElementInputB
=
int8_t
;
using
OperatorClass
=
cutlass
::
arch
::
OpClassTensorOp
;
using
ThreadblockSwizzle
=
cutlass
::
gemm
::
threadblock
::
GemmIdentityThreadblockSwizzle
<
8
>
;
using
DefaultGemmConf
=
cutlass
::
gemm
::
device
::
DefaultGemmConfiguration
<
OperatorClass
,
ArchTag
,
ElementInputA
,
ElementInputB
,
ElementOutput
,
ElementCompute
>
;
using
EpilogueOutputOp
=
typename
DefaultGemmConf
::
EpilogueOutputOp
;
using
GemmKernel_
=
typename
cutlass
::
gemm
::
kernel
::
DefaultGemm
<
ElementInputA
,
cutlass
::
layout
::
RowMajor
,
DefaultGemmConf
::
kAlignmentA
,
ElementInputB
,
cutlass
::
layout
::
ColumnMajor
,
DefaultGemmConf
::
kAlignmentB
,
ElementOutput
,
cutlass
::
layout
::
RowMajor
,
ElementAccumulator
,
OperatorClass
,
ArchTag
,
ThreadblockShape
,
WarpShape
,
InstructionShape
,
EpilogueOutputOp
,
ThreadblockSwizzle
,
NumStages
,
true
,
typename
DefaultGemmConf
::
Operator
>::
GemmKernel
;
using
AlphaColTileIterator
=
cutlass
::
epilogue
::
threadblock
::
PredicatedTileIterator
<
cutlass
::
epilogue
::
threadblock
::
OutputTileOptimalThreadMap
<
typename
GemmKernel_
::
Epilogue
::
OutputTileIterator
::
ThreadMap
::
Shape
,
typename
GemmKernel_
::
Epilogue
::
OutputTileIterator
::
ThreadMap
::
Count
,
GemmKernel_
::
Epilogue
::
OutputTileIterator
::
ThreadMap
::
kThreads
,
GemmKernel_
::
Epilogue
::
OutputTileIterator
::
kElementsPerAccess
,
cutlass
::
sizeof_bits
<
ElementOutput
>::
value
>
,
ElementCompute
>
;
using
EpilogueVisitor
=
typename
cutlass
::
epilogue
::
threadblock
::
EpilogueVisitorPerRowPerCol
<
ThreadblockShape
,
GemmKernel_
::
kThreadCount
,
AlphaColTileIterator
,
typename
GemmKernel_
::
Epilogue
::
OutputTileIterator
,
ElementAccumulator
,
ElementCompute
,
EpilogueOutputOp
>
;
using
Epilogue
=
typename
cutlass
::
epilogue
::
threadblock
::
EpilogueWithVisitorFromExistingEpilogue
<
EpilogueVisitor
,
typename
GemmKernel_
::
Epilogue
>::
Epilogue
;
using
GemmKernel
=
cutlass
::
gemm
::
kernel
::
GemmWithEpilogueVisitor
<
typename
GemmKernel_
::
Mma
,
Epilogue
,
ThreadblockSwizzle
>
;
using
Gemm
=
cutlass
::
gemm
::
device
::
GemmUniversalBaseCompat
<
GemmKernel
>
;
Gemm
gemm_op
;
auto
a_ptr
=
static_cast
<
ElementInputA
*>
(
const_cast
<
void
*>
(
a
));
auto
b_ptr
=
static_cast
<
ElementInputB
*>
(
const_cast
<
void
*>
(
b
));
auto
o_ptr
=
static_cast
<
ElementOutput
*>
(
const_cast
<
void
*>
(
out
));
auto
a_s_ptr
=
static_cast
<
ElementCompute
*>
(
const_cast
<
void
*>
(
a_scale
));
auto
b_s_ptr
=
static_cast
<
ElementCompute
*>
(
const_cast
<
void
*>
(
b_scale
));
ElementOutput
*
bias_ptr
=
nullptr
;
int64_t
ldc
=
0
;
if
(
bias
)
{
bias_ptr
=
static_cast
<
ElementOutput
*>
(
const_cast
<
void
*>
(
bias
));
}
typename
EpilogueOutputOp
::
Params
linearScalingParams
;
typename
EpilogueVisitor
::
Arguments
visitor_args
{
linearScalingParams
};
typename
Gemm
::
Arguments
args
{
{
m
,
n
,
k
},
{
a_ptr
,
lda
},
{
b_ptr
,
ldb
},
{
b_s_ptr
,
0
},
{
a_s_ptr
,
0
},
{
bias_ptr
,
ldc
},
{
o_ptr
,
ldd
},
visitor_args
};
check_cutlass_status
(
gemm_op
.
can_implement
(
args
));
auto
status
=
gemm_op
(
args
,
nullptr
,
(
cudaStream_t
)
stream
);
check_cutlass_status
(
status
);
}
template
<
typename
ElementOutput
,
typename
ArchTag
,
typename
InstructionShape
>
void
sm75_dispatch_shape
(
void
*
out
,
const
void
*
a
,
const
void
*
b
,
const
void
*
a_scale
,
const
void
*
b_scale
,
const
void
*
bias
,
int
m
,
int
n
,
int
k
,
int
lda
,
int
ldb
,
int
ldd
,
void
*
stream
)
{
if
(
m
<=
32
)
{
cutlass_int8_scaled_mm
<
ElementOutput
,
ArchTag
,
cutlass
::
gemm
::
GemmShape
<
32
,
128
,
64
>
,
cutlass
::
gemm
::
GemmShape
<
32
,
64
,
64
>
,
InstructionShape
,
2
>
(
out
,
a
,
b
,
a_scale
,
b_scale
,
bias
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
stream
);
}
else
if
(
m
<=
64
)
{
cutlass_int8_scaled_mm
<
ElementOutput
,
ArchTag
,
cutlass
::
gemm
::
GemmShape
<
64
,
128
,
128
>
,
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
,
InstructionShape
,
2
>
(
out
,
a
,
b
,
a_scale
,
b_scale
,
bias
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
stream
);
}
else
if
(
m
<=
256
)
{
cutlass_int8_scaled_mm
<
ElementOutput
,
ArchTag
,
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
128
>
,
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
,
InstructionShape
,
2
>
(
out
,
a
,
b
,
a_scale
,
b_scale
,
bias
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
stream
);
}
else
{
cutlass_int8_scaled_mm
<
ElementOutput
,
ArchTag
,
cutlass
::
gemm
::
GemmShape
<
128
,
128
,
64
>
,
cutlass
::
gemm
::
GemmShape
<
64
,
64
,
64
>
,
InstructionShape
,
2
>
(
out
,
a
,
b
,
a_scale
,
b_scale
,
bias
,
m
,
n
,
k
,
lda
,
ldb
,
ldd
,
stream
);
}
}
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm80_dispatch_shape(void *out, const void *a, const void *b,
                         const void *a_scale, const void *b_scale, const void *bias,
                         int m, int n, int k, int lda, int ldb, int ldd, void *stream) {
    if (m <= 16) {
        if (n <= 4096) {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<16, 64, 128>,
                                   cutlass::gemm::GemmShape<16, 64, 64>,
                                   InstructionShape, 6>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        } else {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<16, 64, 128>,
                                   cutlass::gemm::GemmShape<16, 64, 64>,
                                   InstructionShape, 5>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        }
    } else if (m <= 32) {
        if (n <= 4096) {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<32, 64, 128>,
                                   cutlass::gemm::GemmShape<32, 64, 64>,
                                   InstructionShape, 6>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        } else {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<32, 64, 128>,
                                   cutlass::gemm::GemmShape<32, 64, 64>,
                                   InstructionShape, 5>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        }
    } else if (m <= 64) {
        if (n <= 4096) {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<64, 64, 128>,
                                   cutlass::gemm::GemmShape<32, 64, 64>,
                                   InstructionShape, 5>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        } else {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<64, 128, 128>,
                                   cutlass::gemm::GemmShape<64, 64, 64>,
                                   InstructionShape, 5>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        }
    } else if (m <= 128 && n < 8192) {
        cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                               cutlass::gemm::GemmShape<64, 128, 128>,
                               cutlass::gemm::GemmShape<64, 64, 64>,
                               InstructionShape, 5>(
            out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
    } else {
        cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                               cutlass::gemm::GemmShape<128, 128, 64>,
                               cutlass::gemm::GemmShape<64, 64, 64>,
                               InstructionShape, 5>(
            out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
    }
}
// Dispatch shape for sm89 (L40S, L20, RTX 4090), according to:
// https://github.com/vllm-project/vllm/blob/main/csrc/quantization/cutlass_w8a8/scaled_mm_c2x_sm89_int8_dispatch.cuh
template <typename ElementOutput, typename ArchTag, typename InstructionShape>
void sm89_dispatch_shape(void *out, const void *a, const void *b,
                         const void *a_scale, const void *b_scale, const void *bias,
                         int m, int n, int k, int lda, int ldb, int ldd, void *stream) {
    if (m <= 16) {
        if (n <= 8192) {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<16, 64, 128>,
                                   cutlass::gemm::GemmShape<16, 64, 64>,
                                   InstructionShape, 5>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        } else {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<16, 128, 128>,
                                   cutlass::gemm::GemmShape<16, 64, 64>,
                                   InstructionShape, 4>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        }
    } else if (m <= 32) {
        if (n <= 8192) {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<32, 64, 128>,
                                   cutlass::gemm::GemmShape<16, 64, 64>,
                                   InstructionShape, 5>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        } else {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<32, 128, 128>,
                                   cutlass::gemm::GemmShape<32, 64, 64>,
                                   InstructionShape, 4>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        }
    } else if (m <= 64) {
        if (n <= 8192) {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<64, 64, 128>,
                                   cutlass::gemm::GemmShape<32, 64, 64>,
                                   InstructionShape, 5>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        } else {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<64, 128, 128>,
                                   cutlass::gemm::GemmShape<64, 64, 64>,
                                   InstructionShape, 3>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        }
    } else if (m <= 128) {
        if (n <= 8192) {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<64, 128, 128>,
                                   cutlass::gemm::GemmShape<32, 64, 64>,
                                   InstructionShape, 3>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        } else if (n <= 16384) {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<128, 128, 64>,
                                   cutlass::gemm::GemmShape<64, 64, 64>,
                                   InstructionShape, 5>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        } else {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<64, 64, 128>,
                                   cutlass::gemm::GemmShape<32, 64, 64>,
                                   InstructionShape, 5>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        }
    } else if (m <= 256) {
        if (n <= 4096) {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<64, 128, 128>,
                                   cutlass::gemm::GemmShape<64, 64, 64>,
                                   InstructionShape, 3>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        } else if (n <= 8192) {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<128, 128, 64>,
                                   cutlass::gemm::GemmShape<64, 64, 64>,
                                   InstructionShape, 5>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        } else if (n <= 16384) {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<256, 128, 64>,
                                   cutlass::gemm::GemmShape<64, 64, 64>,
                                   InstructionShape, 3>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        } else {
            cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                                   cutlass::gemm::GemmShape<128, 128, 64>,
                                   cutlass::gemm::GemmShape<64, 64, 64>,
                                   InstructionShape, 5>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        }
    } else {
        cutlass_int8_scaled_mm<ElementOutput, ArchTag,
                               cutlass::gemm::GemmShape<128, 128, 64>,
                               cutlass::gemm::GemmShape<64, 64, 64>,
                               InstructionShape, 5>(
            out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
    }
}
template <typename ElementOutput, typename TileShape, typename ClusterShape,
          typename MainloopScheduleType, bool WithBias>
void cutlass_int8_scaled_mm_sm90(void *out, const void *a, const void *b,
                                 const void *a_scale, const void *b_scale, const void *bias,
                                 int m, int n, int k, int lda, int ldb, int ldd, void *stream) {
    using ArchTag = cutlass::arch::Sm90;
    using ElementAccumulator = int32_t;
    using ElementCompute = float;
    using ElementInputA = int8_t;
    using ElementInputB = int8_t;

    static constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementInputA>::value;
    static constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementInputB>::value;
    static constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementOutput>::value;
    static constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits<ElementOutput>::value;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using EpilogueScheduleType = cutlass::epilogue::TmaWarpSpecialized;
    using TileSchedulerType = cutlass::gemm::PersistentScheduler;

    using XScale = cutlass::epilogue::fusion::Sm90ColBroadcast<
        0, TileShape, ElementCompute, ElementCompute, Stride<Int<1>, Int<0>, Int<0>>>;
    using WScale = cutlass::epilogue::fusion::Sm90RowBroadcast<
        0, TileShape, ElementCompute, ElementCompute, Stride<Int<0>, Int<1>, Int<0>>>;
    using Bias = cutlass::epilogue::fusion::Sm90RowBroadcast<
        0, TileShape, ElementOutput, ElementOutput, Stride<Int<0>, Int<1>, Int<0>>>;
    using Accum = cutlass::epilogue::fusion::Sm90AccFetch;

    // Scale
    using Compute0 = cutlass::epilogue::fusion::Sm90Compute<
        cutlass::multiplies, ElementCompute, ElementCompute,
        cutlass::FloatRoundStyle::round_to_nearest>;
    using EVTCompute0 = cutlass::epilogue::fusion::Sm90EVT<Compute0, WScale, Accum>;
    using Compute1 = cutlass::epilogue::fusion::Sm90Compute<
        cutlass::multiplies, ElementOutput, ElementCompute,
        cutlass::FloatRoundStyle::round_to_nearest>;
    using EVTCompute1 = cutlass::epilogue::fusion::Sm90EVT<Compute1, XScale, EVTCompute0>;

    // With bias
    using ComputeWithBias = cutlass::epilogue::fusion::Sm90Compute<
        cutlass::multiply_add, ElementOutput, ElementCompute,
        cutlass::FloatRoundStyle::round_to_nearest>;
    using EVTComputeWithBias = cutlass::epilogue::fusion::Sm90EVT<ComputeWithBias, XScale, EVTCompute0, Bias>;

    using EpilogueEVT = typename cutlass::platform::conditional<WithBias, EVTComputeWithBias, EVTCompute1>::type;

    using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
        ArchTag, OperatorClass, TileShape, ClusterShape,
        cutlass::epilogue::collective::EpilogueTileAuto,
        ElementAccumulator, ElementCompute,
        ElementOutput, cutlass::layout::RowMajor, AlignmentC,
        ElementOutput, cutlass::layout::RowMajor, AlignmentOutput,
        EpilogueScheduleType, EpilogueEVT>::CollectiveOp;

    using Stages = cutlass::gemm::collective::StageCountAutoCarveout<
        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>;

    using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
        ArchTag, OperatorClass,
        ElementInputA, cutlass::layout::RowMajor, AlignmentA,
        ElementInputB, cutlass::layout::ColumnMajor, AlignmentB,
        ElementAccumulator, TileShape, ClusterShape, Stages,
        MainloopScheduleType>::CollectiveOp;

    using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
        Shape<int, int, int, int>, // Indicates ProblemShape
        CollectiveMainloop, CollectiveEpilogue, TileSchedulerType>;

    using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
    Gemm gemm_op;

    auto a_ptr = static_cast<ElementInputA *>(const_cast<void *>(a));
    auto b_ptr = static_cast<ElementInputB *>(const_cast<void *>(b));
    auto o_ptr = static_cast<ElementOutput *>(const_cast<void *>(out));
    auto a_s_ptr = static_cast<ElementCompute *>(const_cast<void *>(a_scale));
    auto b_s_ptr = static_cast<ElementCompute *>(const_cast<void *>(b_scale));

    using StrideA = typename Gemm::GemmKernel::StrideA;
    using StrideB = typename Gemm::GemmKernel::StrideB;
    using StrideC = typename Gemm::GemmKernel::StrideC;
    using StrideD = typename Gemm::GemmKernel::StrideD;

    StrideA stride_a = cutlass::make_cute_packed_stride(StrideA{}, make_shape(m, k, 1));
    StrideB stride_b = cutlass::make_cute_packed_stride(StrideB{}, make_shape(n, k, 1));
    StrideC stride_c;
    StrideD stride_d = cutlass::make_cute_packed_stride(StrideD{}, make_shape(m, n, 1));

    typename Gemm::Arguments args = {
        cutlass::gemm::GemmUniversalMode::kGemm,
        {m, n, k, 1},
        {a_ptr, stride_a, b_ptr, stride_b},
        {{}, // epilogue.thread
         nullptr, stride_c, o_ptr, stride_d}};

    if constexpr (WithBias) {
        ElementOutput *bias_ptr = static_cast<ElementOutput *>(const_cast<void *>(bias));
        // ElementOutput* bias_ptr = static_cast<ElementOutput*>(bias->data_ptr());
        args.epilogue.thread = {
            {a_s_ptr},
            {{b_s_ptr}, {}, {}},
            {bias_ptr},
            {},
        };
    } else {
        args.epilogue.thread = {
            {a_s_ptr},
            {{b_s_ptr}, {}, {}},
            {},
        };
    }

    // auto workspace = torch::empty(
    //     gemm_op.get_workspace_size(args), torch::TensorOptions().dtype(torch::kUInt8).device(mat_a.device()));
    check_cutlass_status(gemm_op.can_implement(args));
    auto status = gemm_op(args, nullptr, (cudaStream_t)stream);
    check_cutlass_status(status);
    // TORCH_CHECK(status == cutlass::Status::kSuccess, "gemm executioin failed, error: ", cutlassGetStatusString(status));
}
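The epilogue visitor tree built above fuses the per-channel dequantization into the GEMM epilogue. As a minimal elementwise sketch (assuming, per the Col/RowBroadcast strides, that a_scale holds one value per row and b_scale one value per column), it computes:

// acc[i][j] : int32 accumulator of the int8 GEMM
// t[i][j]   = b_scale[j] * (float)acc[i][j];                    // WScale * Accum   (EVTCompute0)
// out[i][j] = ElementOutput(a_scale[i] * t[i][j]);              // XScale * above   (EVTCompute1)
// out[i][j] = ElementOutput(a_scale[i] * t[i][j] + bias[j]);    // multiply_add path when WithBias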
template <typename ElementOutput, typename TileShape, typename ClusterShape, typename MainloopScheduleType>
void sm90_dispatch_bias(void *out, const void *a, const void *b,
                        const void *a_scale, const void *b_scale, const void *bias,
                        int m, int n, int k, int lda, int ldb, int ldd, void *stream) {
    if (bias) {
        cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, true>(
            out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
    } else {
        cutlass_int8_scaled_mm_sm90<ElementOutput, TileShape, ClusterShape, MainloopScheduleType, false>(
            out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
    }
}
template <typename ElementOutput>
void sm90_dispatch_shape(void *out, const void *a, const void *b,
                         const void *a_scale, const void *b_scale, const void *bias,
                         int m, int n, int k, int lda, int ldb, int ldd, void *stream) {
    if (m <= 32) {
        if (n < 8192) {
            return sm90_dispatch_bias<ElementOutput, Shape<_64, _64, _128>, Shape<_1, _8, _1>,
                                      cutlass::gemm::KernelTmaWarpSpecialized>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        } else {
            return sm90_dispatch_bias<ElementOutput, Shape<_64, _128, _128>, Shape<_1, _8, _1>,
                                      cutlass::gemm::KernelTmaWarpSpecialized>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        }
    } else if (m <= 64) {
        if (n < 8192) {
            return sm90_dispatch_bias<ElementOutput, Shape<_64, _64, _128>, Shape<_1, _4, _1>,
                                      cutlass::gemm::KernelTmaWarpSpecialized>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        } else {
            return sm90_dispatch_bias<ElementOutput, Shape<_64, _64, _256>, Shape<_1, _1, _1>,
                                      cutlass::gemm::KernelTmaWarpSpecialized>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        }
    } else if (m <= 128) {
        if (n <= 4096) {
            return sm90_dispatch_bias<ElementOutput, Shape<_64, _64, _128>, Shape<_2, _1, _1>,
                                      cutlass::gemm::KernelTmaWarpSpecialized>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        } else {
            return sm90_dispatch_bias<ElementOutput, Shape<_64, _128, _128>, Shape<_2, _1, _1>,
                                      cutlass::gemm::KernelTmaWarpSpecialized>(
                out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
        }
    } else {
        return sm90_dispatch_bias<ElementOutput, Shape<_128, _128, _128>, Shape<_2, _1, _1>,
                                  cutlass::gemm::KernelTmaWarpSpecializedPingpong>(
            out, a, b, a_scale, b_scale, bias, m, n, k, lda, ldb, ldd, stream);
    }
}
src/infiniop/ops/scaled_mm/nvidia/int8_gemm_nvidia.cu
0 → 100644
View file @
8d09630a
#include "../../../devices/nvidia/nvidia_handle.cuh"
#include "../../../devices/nvidia/nvidia_kernel_common.cuh"
#ifdef ENABLE_CUTLASS_API
#include "int8_gemm_kernel.cuh"
#endif
#include "../cuda/per_channel_dequant_int8.cuh"
#include "int8_gemm_nvidia.cuh"
template <typename Tdata>
INFINIOP_CUDA_KERNEL postSym(
    Tdata *y, int32_t *y_packed, const Tdata *bias,
    const int8_t *x_packed, const float *x_scale,
    const int8_t *w_packed, const float *w_scale,
    int M, int K, int N) {
    postSymKernel<Tdata>(y, y_packed, bias, x_packed, x_scale, w_packed, w_scale, M, K, N);
}
template <typename Tdata>
INFINIOP_CUDA_KERNEL postSym(
    Tdata *y, int32_t *y_packed,
    const int8_t *x_packed, const float *x_scale,
    const int8_t *w_packed, const float *w_scale,
    int M, int K, int N) {
    postSymKernel<Tdata>(y, y_packed, x_packed, x_scale, w_packed, w_scale, M, K, N);
}
namespace op::i8gemm::nvidia {

struct Descriptor::Opaque {
    std::shared_ptr<device::nvidia::Handle::Internal> internal;
};

Descriptor::~Descriptor() {
    delete _opaque;
}
#ifdef ENABLE_NVIDIA_API
inline int getSMVersion() {
    int device{-1};
    CHECK_CUDA(cudaGetDevice(&device));
    int sm_major = 0;
    int sm_minor = 0;
    CHECK_CUDA(cudaDeviceGetAttribute(&sm_major, cudaDevAttrComputeCapabilityMajor, device));
    CHECK_CUDA(cudaDeviceGetAttribute(&sm_minor, cudaDevAttrComputeCapabilityMinor, device));
    return sm_major * 10 + sm_minor;
}
#endif
infiniStatus_t Descriptor::create(
    infiniopHandle_t handle_,
    Descriptor **desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t bias_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t a_scale_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t b_scale_desc) {
    auto handle = reinterpret_cast<device::nvidia::Handle *>(handle_);
    auto dtype = out_desc->dtype();
    CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_BF16);

    auto result = I8GemmInfo::create(out_desc, a_desc, b_desc, MatrixLayout::COL_MAJOR);
    CHECK_RESULT(result);

    size_t workspace_size = out_desc->dim(0) * out_desc->dim(1) * sizeof(int32_t);

    *desc_ptr = new Descriptor(
        new Opaque{handle->internal()},
        result.take(),
        workspace_size,
        dtype,
        handle->device,
        handle->device_id);
    return INFINI_STATUS_SUCCESS;
}
template <unsigned int BLOCK_SIZE, typename Tdata>
infiniStatus_t Descriptor::launchKernel(
    const I8GemmInfo &info,
    Tdata *y, const Tdata *bias,
    const int8_t *x_packed, const float *x_scale,
    const int8_t *w_packed, const float *w_scale,
    void *stream_, void *workspace) const {
    cudaStream_t stream = (cudaStream_t)stream_;
    int M = (int)info.m;
    int K = (int)info.k;
    int N = (int)info.n;

    char *workspace_ptr = reinterpret_cast<char *>(workspace);
    int32_t *y_packed = reinterpret_cast<int32_t *>(workspace_ptr);

    const int32_t alpha_I = 1;
    const int32_t beta_I = 0;
    int lda = K; // w_packed is column-major [K, N]
    int ldb = K; // x_packed is row-major [M, K]
    int ldc = N; // y_packed is row-major [M, N]

    CHECK_STATUS(this->_opaque->internal->useCublas(
        stream,
        [&](cublasHandle_t handle) {
            CHECK_CUBLAS(cublasGemmEx(
                handle,
                CUBLAS_OP_T, // A = w_packed^T : [N, K]
                CUBLAS_OP_N, // B = x_packed^T viewed column-major : [K, M]
                N,           // m
                M,           // n
                K,           // k
                &alpha_I,
                w_packed, CUDA_R_8I, lda,
                x_packed, CUDA_R_8I, ldb,
                &beta_I,
                y_packed, CUDA_R_32I, ldc,
                CUBLAS_COMPUTE_32I,
                CUBLAS_GEMM_DEFAULT));
            return INFINI_STATUS_SUCCESS;
        }));

    constexpr unsigned int BLOCK_SIZE_x = 32;
    constexpr unsigned int BLOCK_SIZE_y = 32;
    int num_block_x = (N + BLOCK_SIZE_x - 1) / BLOCK_SIZE_x;
    int num_block_y = (M + BLOCK_SIZE_y - 1) / BLOCK_SIZE_y;
    dim3 block_dim(BLOCK_SIZE_x, BLOCK_SIZE_y, 1);
    dim3 grid_dim(num_block_x, num_block_y, 1);

    if (bias == nullptr) {
        postSym<Tdata><<<grid_dim, block_dim, 0, stream>>>(
            y, y_packed, x_packed, x_scale, w_packed, w_scale, M, K, N);
    } else {
        postSym<Tdata><<<grid_dim, block_dim, 0, stream>>>(
            y, y_packed, bias, x_packed, x_scale, w_packed, w_scale, M, K, N);
    }
    return INFINI_STATUS_SUCCESS;
}
infiniStatus_t Descriptor::calculate(
    void *workspace, size_t workspace_size,
    void *out, const void *bias,
    const void *a, const void *a_scale,
    const void *b, const void *b_scale,
    void *stream) const {
#if defined(ENABLE_NVIDIA_API) && defined(ENABLE_CUTLASS_API)
    auto sm_version = getSMVersion();
    if (sm_version >= 75 && sm_version < 80) {
        CHECK_DTYPE(this->_out_dtype, INFINI_DTYPE_F16);
        sm75_dispatch_shape<cutlass::half_t, cutlass::arch::Sm75, cutlass::gemm::GemmShape<8, 8, 16>>(
            out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k,
            _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
    } else if (sm_version >= 80 && sm_version < 90) {
        // sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
        if (sm_version == 86 || sm_version == 89) {
            if (this->_out_dtype == INFINI_DTYPE_BF16) {
                sm89_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
                    out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k,
                    _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
            } else {
                sm89_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
                    out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k,
                    _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
            }
        } else {
            if (this->_out_dtype == INFINI_DTYPE_BF16) {
                sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
                    out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k,
                    _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
            } else {
                sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
                    out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k,
                    _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
            }
        }
    } else if (sm_version == 90) {
#if defined CUDA_VERSION && CUDA_VERSION >= 12000
        // cutlass 3.x
        if (this->_out_dtype == INFINI_DTYPE_BF16) {
            sm90_dispatch_shape<cutlass::bfloat16_t>(
                out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k,
                _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
        } else {
            sm90_dispatch_shape<cutlass::half_t>(
                out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k,
                _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
        }
#else
        // fallback to cutlass 2.x
        if (this->_out_dtype == INFINI_DTYPE_BF16) {
            sm80_dispatch_shape<cutlass::bfloat16_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
                out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k,
                _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
        } else {
            sm80_dispatch_shape<cutlass::half_t, cutlass::arch::Sm80, cutlass::gemm::GemmShape<16, 8, 32>>(
                out, a, b, a_scale, b_scale, bias, _info.m, _info.n, _info.k,
                _info.a_matrix.ld(), _info.b_matrix.ld(), _info.out_matrix.ld(), stream);
        }
#endif
    } else {
        return INFINI_STATUS_NOT_IMPLEMENTED;
    }
#elif defined ENABLE_QY_API
#define CALCULATE_LINEAR(BLOCK_SIZE, TDATA) \
    launchKernel<BLOCK_SIZE, TDATA>(_info, (TDATA *)out, (const TDATA *)bias, (const int8_t *)a, (const float *)a_scale, (const int8_t *)b, (const float *)b_scale, stream, workspace)
#define CALCULATE_LINEAR_WITH_BLOCK_SIZE(BLOCK_SIZE)            \
    {                                                           \
        if (this->_out_dtype == INFINI_DTYPE_F16)               \
            return CALCULATE_LINEAR(BLOCK_SIZE, half);          \
        else if (this->_out_dtype == INFINI_DTYPE_F32)          \
            return CALCULATE_LINEAR(BLOCK_SIZE, float);         \
        else if (this->_out_dtype == INFINI_DTYPE_BF16)         \
            return CALCULATE_LINEAR(BLOCK_SIZE, __nv_bfloat16); \
        else                                                    \
            return INFINI_STATUS_BAD_TENSOR_DTYPE;              \
    }
    if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_1024) {
        CALCULATE_LINEAR_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_1024)
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_512) {
        CALCULATE_LINEAR_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_512)
    } else if (_opaque->internal->maxThreadsPerBlock() == CUDA_BLOCK_SIZE_4096) {
        CALCULATE_LINEAR_WITH_BLOCK_SIZE(CUDA_BLOCK_SIZE_4096)
    } else {
        return INFINI_STATUS_DEVICE_ARCHITECTURE_NOT_SUPPORTED;
    }
#endif
    return INFINI_STATUS_SUCCESS;
}

} // namespace op::i8gemm::nvidia
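The cuBLAS path in launchKernel above uses the standard trick that a row-major matrix is the column-major view of its transpose: with op(A)=T on w_packed (column-major [K, N], lda = K) and op(B)=N on x_packed (row-major [M, K] read as a column-major [K, M]), the column-major [N, M] result with ldc = N is exactly the row-major y_packed of shape [M, N]. A naive int32 reference of the same accumulation, for illustration only:

// Reference semantics of the cublasGemmEx call (not the actual kernel).
for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
        int32_t acc = 0;
        for (int p = 0; p < K; ++p) {
            // x_packed: row-major [M, K]; w_packed: column-major [K, N]
            acc += (int32_t)x_packed[i * K + p] * (int32_t)w_packed[p + j * K];
        }
        y_packed[i * N + j] = acc; // row-major [M, N]; postSym dequantizes it afterwards
    }
}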
src/infiniop/ops/scaled_mm/nvidia/int8_gemm_nvidia.cuh
0 → 100644
View file @
8d09630a
#ifndef __INT8_GEMM_NVIDIA_API_H__
#define __INT8_GEMM_NVIDIA_API_H__
#include "../int8_gemm.h"
DESCRIPTOR(nvidia)
#endif // __INT8_GEMM_NVIDIA_API_H__
src/infiniop/ops/scaled_mm/operator.cc
0 → 100644
View file @
8d09630a
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/int8_gemm.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#include "nvidia/int8_gemm_nvidia.cuh"
#endif
__C infiniStatus_t infiniopCreateI8GemmDescriptor(
    infiniopHandle_t handle,
    infiniopI8GemmDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t out_desc,
    infiniopTensorDescriptor_t bias_desc,
    infiniopTensorDescriptor_t a_desc,
    infiniopTensorDescriptor_t a_scale_desc,
    infiniopTensorDescriptor_t b_desc,
    infiniopTensorDescriptor_t b_scale_desc) {
#define CREATE(CASE, NAMESPACE)                                               \
    case CASE:                                                                \
        return op::i8gemm::NAMESPACE::Descriptor::create(                     \
            handle,                                                           \
            reinterpret_cast<op::i8gemm::NAMESPACE::Descriptor **>(desc_ptr), \
            out_desc,                                                         \
            bias_desc,                                                        \
            a_desc,                                                           \
            a_scale_desc,                                                     \
            b_desc,                                                           \
            b_scale_desc);

    switch (handle->device) {
#if defined(ENABLE_NVIDIA_API)
        CREATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#if defined(ENABLE_QY_API)
        CREATE(INFINI_DEVICE_QY, nvidia)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef CREATE
}

__C infiniStatus_t infiniopGetI8GemmWorkspaceSize(infiniopI8GemmDescriptor_t desc, size_t *size) {
    switch (desc->device_type) {
#define GET(CASE, NAMESPACE)                                                                     \
    case CASE:                                                                                   \
        *size = reinterpret_cast<op::i8gemm::NAMESPACE::Descriptor *>(desc)->minWorkspaceSize(); \
        return INFINI_STATUS_SUCCESS;
#if defined(ENABLE_NVIDIA_API)
        GET(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#if defined(ENABLE_QY_API)
        GET(INFINI_DEVICE_QY, nvidia)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef GET
}

__C infiniStatus_t infiniopI8Gemm(
    infiniopI8GemmDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *out,
    const void *bias,
    const void *a,
    const void *a_scale,
    const void *b,
    const void *b_scale,
    void *stream) {
#define CACULATE(CASE, NAMESPACE)                                                       \
    case CASE:                                                                          \
        return reinterpret_cast<op::i8gemm::NAMESPACE::Descriptor *>(desc)->calculate(  \
            workspace, workspace_size, out, bias, a, a_scale, b, b_scale, stream);

    switch (desc->device_type) {
#if defined(ENABLE_NVIDIA_API)
        CACULATE(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#if defined(ENABLE_QY_API)
        CACULATE(INFINI_DEVICE_QY, nvidia)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef CACULATE
}

__C infiniStatus_t infiniopDestroyI8GemmDescriptor(infiniopI8GemmDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE)                                            \
    case CASE:                                                              \
        delete reinterpret_cast<op::i8gemm::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#if defined(ENABLE_NVIDIA_API)
        DESTROY(INFINI_DEVICE_NVIDIA, nvidia)
#endif
#if defined(ENABLE_QY_API)
        DESTROY(INFINI_DEVICE_QY, nvidia)
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
#undef DESTROY
}
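A minimal host-side call sequence for this API, assuming the handle, tensor descriptors, device buffers, and stream have already been created elsewhere; the CHECK/CHECK_OK error-handling macro below is a placeholder, not part of this diff:

infiniopI8GemmDescriptor_t desc = nullptr;
CHECK_OK(infiniopCreateI8GemmDescriptor(handle, &desc,
                                        out_desc, bias_desc,
                                        a_desc, a_scale_desc,
                                        b_desc, b_scale_desc));

size_t workspace_size = 0;
CHECK_OK(infiniopGetI8GemmWorkspaceSize(desc, &workspace_size));
void *workspace = nullptr;
cudaMalloc(&workspace, workspace_size); // holds the int32 y_packed buffer on the cuBLAS path

CHECK_OK(infiniopI8Gemm(desc, workspace, workspace_size,
                        out, bias, a, a_scale, b, b_scale, stream));
cudaStreamSynchronize(stream);

cudaFree(workspace);
CHECK_OK(infiniopDestroyI8GemmDescriptor(desc));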
src/infiniop/ops/sigmoid/operator.cc
View file @
8d09630a
...
...
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/sigmoid_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/sigmoid_nvidia.cuh"
#endif
...
...
@@ -34,6 +34,9 @@ __C infiniStatus_t infiniopCreateSigmoidDescriptor(
#ifdef ENABLE_QY_API
        CREATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
        CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
...
@@ -59,6 +62,10 @@ __C infiniStatus_t infiniopGetSigmoidWorkspaceSize(infiniopSigmoidDescriptor_t d
#ifdef ENABLE_QY_API
        GET(INFINI_DEVICE_QY, nvidia)
#endif
#ifdef ENABLE_ALI_API
        GET(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
...
...
@@ -91,6 +98,9 @@ __C infiniStatus_t infiniopSigmoid(
#ifdef ENABLE_QY_API
        CALCULATE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
        CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
...
@@ -118,6 +128,9 @@ infiniopDestroySigmoidDescriptor(infiniopSigmoidDescriptor_t desc) {
#ifdef ENABLE_QY_API
        DELETE(INFINI_DEVICE_QY, nvidia);
#endif
#ifdef ENABLE_ALI_API
        DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
...
src/infiniop/ops/silu/operator.cc
View file @
8d09630a
...
...
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/silu_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_ALI_API)
#include "nvidia/silu_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
...
...
@@ -46,6 +46,9 @@ __C infiniStatus_t infiniopCreateSiluDescriptor(
#ifdef ENABLE_MOORE_API
        CREATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_ALI_API
        CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
...
@@ -77,6 +80,10 @@ __C infiniStatus_t infiniopGetSiluWorkspaceSize(infiniopSiluDescriptor_t desc, s
#ifdef ENABLE_MOORE_API
        GET(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_ALI_API
        GET(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
...
...
@@ -115,6 +122,9 @@ __C infiniStatus_t infiniopSilu(
#ifdef ENABLE_MOORE_API
        CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_ALI_API
        CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
...
@@ -148,6 +158,9 @@ infiniopDestroySiluDescriptor(infiniopSiluDescriptor_t desc) {
#ifdef ENABLE_MOORE_API
        DELETE(INFINI_DEVICE_MOORE, moore);
#endif
#ifdef ENABLE_ALI_API
        DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
...
src/infiniop/ops/silu_and_mul/info.h
0 → 100644
View file @
8d09630a
#ifndef __SILU_AND_MUL_INFO_H__
#define __SILU_AND_MUL_INFO_H__
#include "../../../utils.h"
#include "../../tensor.h"
#include <vector>
namespace op::silu_and_mul {

class SiluAndMulInfo {
    SiluAndMulInfo() = default;

public:
    infiniDtype_t dtype;
    size_t batch_size;
    size_t out_hidden_dim;

    static utils::Result<SiluAndMulInfo> create(
        infiniopTensorDescriptor_t y_desc,
        infiniopTensorDescriptor_t x_desc) {
        auto dtype = y_desc->dtype();
        auto x_shape = x_desc->shape();
        auto y_shape = y_desc->shape();
        auto ndim = x_desc->ndim();
        if (ndim != y_desc->ndim()) {
            return INFINI_STATUS_BAD_PARAM;
        }
        if (x_shape[ndim - 1] != 2 * y_shape[ndim - 1]) {
            return INFINI_STATUS_BAD_TENSOR_SHAPE;
        }
        size_t batch = 1;
        for (int i = 0; i < (int)ndim - 1; ++i) {
            if (x_shape[i] != y_shape[i]) {
                return INFINI_STATUS_BAD_TENSOR_SHAPE;
            }
            batch *= y_shape[i];
        }
        return utils::Result<SiluAndMulInfo>(SiluAndMulInfo{dtype, batch, y_shape[ndim - 1]});
    }

private:
    SiluAndMulInfo(infiniDtype_t dtype, size_t batch, size_t hidden)
        : dtype(dtype), batch_size(batch), out_hidden_dim(hidden) {}
};

} // namespace op::silu_and_mul
#endif // __SILU_AND_MUL_INFO_H__
src/infiniop/ops/silu_and_mul/moore/silu_and_mul_moore.h
0 → 100644
View file @
8d09630a
#ifndef __SILU_ADN_MUL_MOORE_API_H__
#define __SILU_ADN_MUL_MOORE_API_H__
#include "../silu_and_mul.h"
DESCRIPTOR(moore)
#endif // __SILU_ADN_MUL_MOORE_API_H__
src/infiniop/ops/silu_and_mul/moore/silu_and_mul_moore.mu
0 → 100644
View file @
8d09630a
#include "../../../devices/moore/moore_common.h"
#include "../../../devices/moore/moore_handle.h"
#include "silu_and_mul_moore.h"
#include <musa_bf16.h>
#include <memory>
namespace op::silu_and_mul::moore {
struct Descriptor::Opaque {
std::shared_ptr<device::moore::Handle::Internal> internal;
};
Descriptor::~Descriptor() {
delete _opaque;
}
infiniStatus_t Descriptor::create(
infiniopHandle_t handle_,
Descriptor **desc_ptr,
infiniopTensorDescriptor_t y_desc,
infiniopTensorDescriptor_t x_desc) {
if (!desc_ptr) {
return INFINI_STATUS_BAD_PARAM;
}
auto handle = reinterpret_cast<device::moore::Handle *>(handle_);
auto dtype = y_desc->dtype();
CHECK_DTYPE(dtype, INFINI_DTYPE_F16, INFINI_DTYPE_F32, INFINI_DTYPE_BF16);
if (x_desc->dtype() != dtype) {
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
auto result = SiluAndMulInfo::create(y_desc, x_desc);
CHECK_RESULT(result);
auto info = result.take();
*desc_ptr = new Descriptor(
new Opaque{handle->internal()},
std::move(info),
0,
handle->device, handle->device_id);
return INFINI_STATUS_SUCCESS;
}
template <typename T>
infiniStatus_t calculate_impl(
const SiluAndMulInfo &info,
std::shared_ptr<device::moore::Handle::Internal> &internal,
void *y,
const void *x,
void *stream) {
return internal->useMudnn(
(musaStream_t)stream,
[&](::musa::dnn::Handle &mudnn_handle) -> infiniStatus_t {
::musa::dnn::Tensor x_t, y_t;
if constexpr (std::is_same_v<T, half>) {
x_t.SetType(::musa::dnn::Tensor::Type::HALF);
y_t.SetType(::musa::dnn::Tensor::Type::HALF);
} else if constexpr (std::is_same_v<T, __mt_bfloat16>) {
x_t.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
y_t.SetType(::musa::dnn::Tensor::Type::BFLOAT16);
} else {
x_t.SetType(::musa::dnn::Tensor::Type::FLOAT);
y_t.SetType(::musa::dnn::Tensor::Type::FLOAT);
}
x_t.SetAddr(const_cast<void *>(x));
y_t.SetAddr(y);
// --- Construct 2D dimension information ---
// Explicitly distinguish between Batch and Hidden dimensions
int64_t b = static_cast<int64_t>(info.batch_size);
int64_t h = static_cast<int64_t>(info.out_hidden_dim);
// Input x logical shape is [batch, 2 * hidden]
std::array<int64_t, 2> x_dims = {b, h * 2};
std::array<int64_t, 2> x_strides = {h * 2, 1};
// Output y logical shape is [batch, hidden]
std::array<int64_t, 2> y_dims = {b, h};
std::array<int64_t, 2> y_strides = {h, 1};
x_t.SetNdInfo(2, x_dims.data(), x_strides.data());
y_t.SetNdInfo(2, y_dims.data(), y_strides.data());
// Invoke muDNN SwiGLU
// muDNN will split each row (length 2*h) internally,
// muDNN treats the first h elements of input x as the 'gate'
// and the following h elements as the 'up' projection.
::musa::dnn::SwiGlu swiglu;
swiglu.Run(mudnn_handle, y_t, x_t);
return INFINI_STATUS_SUCCESS;
});
}
infiniStatus_t Descriptor::calculate(
void *workspace, size_t workspace_size,
void *y, const void *x,
void *stream) const {
infiniDtype_t dtype = _info.dtype;
switch (dtype) {
case INFINI_DTYPE_F16:
return calculate_impl<half>(_info, _opaque->internal, y, x, stream);
case INFINI_DTYPE_F32:
return calculate_impl<float>(_info, _opaque->internal, y, x, stream);
case INFINI_DTYPE_BF16:
return calculate_impl<__mt_bfloat16>(_info, _opaque->internal, y, x, stream);
default:
return INFINI_STATUS_BAD_TENSOR_DTYPE;
}
}
} // namespace op::silu_and_mul::moore
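For reference, this operator is the standard SwiGLU gating y = silu(gate) * up, where gate is the first half of the last dimension of x and up is the second half, as the comments above describe. A CPU sketch of the per-row semantics (illustration only, not the muDNN kernel):

#include <cmath>
#include <cstdint>

// One input row of length 2*h, one output row of length h; float only for clarity.
inline void silu_and_mul_row_ref(float *y, const float *x, int64_t h) {
    for (int64_t j = 0; j < h; ++j) {
        float gate = x[j];   // first h elements: gate projection
        float up = x[j + h]; // next h elements: up projection
        y[j] = gate / (1.0f + std::exp(-gate)) * up; // silu(gate) * up
    }
}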
src/infiniop/ops/silu_and_mul/operator.cc
0 → 100644
View file @
8d09630a
#include "../../operator.h"
#include "../../handle.h"
#include "infiniop/ops/silu_and_mul.h"
#ifdef ENABLE_MOORE_API
#include "moore/silu_and_mul_moore.h"
#endif
__C infiniStatus_t infiniopCreateSiluAndMulDescriptor(
    infiniopHandle_t handle,
    infiniopSiluAndMulDescriptor_t *desc_ptr,
    infiniopTensorDescriptor_t y_desc,
    infiniopTensorDescriptor_t x_desc) {
#define CREATE(CASE, NAMESPACE)                                                     \
    case CASE:                                                                      \
        return op::silu_and_mul::NAMESPACE::Descriptor::create(                     \
            handle,                                                                 \
            reinterpret_cast<op::silu_and_mul::NAMESPACE::Descriptor **>(desc_ptr), \
            y_desc,                                                                 \
            x_desc);

    switch (handle->device) {
#ifdef ENABLE_MOORE_API
        CREATE(INFINI_DEVICE_MOORE, moore);
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopGetSiluAndMulWorkspaceSize(infiniopSiluAndMulDescriptor_t desc, size_t *size) {
#define GET(CASE, NAMESPACE)                                                                              \
    case CASE:                                                                                            \
        *size = reinterpret_cast<const op::silu_and_mul::NAMESPACE::Descriptor *>(desc)->workspaceSize(); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_MOORE_API
        GET(INFINI_DEVICE_MOORE, moore);
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopSiluAndMul(
    infiniopSiluAndMulDescriptor_t desc,
    void *workspace,
    size_t workspace_size,
    void *y,
    const void *x,
    void *stream) {
#define CALCULATE(CASE, NAMESPACE)                                                                  \
    case CASE:                                                                                      \
        return reinterpret_cast<const op::silu_and_mul::NAMESPACE::Descriptor *>(desc)->calculate(  \
            workspace, workspace_size, y, x, stream);

    switch (desc->device_type) {
#ifdef ENABLE_MOORE_API
        CALCULATE(INFINI_DEVICE_MOORE, moore);
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}

__C infiniStatus_t infiniopDestroySiluAndMulDescriptor(infiniopSiluAndMulDescriptor_t desc) {
#define DESTROY(CASE, NAMESPACE)                                                        \
    case CASE:                                                                          \
        delete reinterpret_cast<const op::silu_and_mul::NAMESPACE::Descriptor *>(desc); \
        return INFINI_STATUS_SUCCESS;

    switch (desc->device_type) {
#ifdef ENABLE_MOORE_API
        DESTROY(INFINI_DEVICE_MOORE, moore);
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
}
src/infiniop/ops/silu_and_mul/silu_and_mul.h
0 → 100644
View file @
8d09630a
#ifndef SILU_AND_MUL_H
#define SILU_AND_MUL_H
#include "../../operator.h"
#include "info.h"
#define DESCRIPTOR(NAMESPACE) \
\
namespace op::silu_and_mul::NAMESPACE { \
class Descriptor final : public InfiniopDescriptor { \
struct Opaque; \
Opaque *_opaque; \
SiluAndMulInfo _info; \
size_t _workspace_size; \
\
Descriptor( \
Opaque *opaque, \
SiluAndMulInfo info, \
size_t workspace_size, \
infiniDevice_t device_type, \
int device_id) \
: InfiniopDescriptor{device_type, device_id}, \
_opaque(opaque), \
_info(info), \
_workspace_size(workspace_size) {} \
\
public: \
~Descriptor(); \
\
size_t workspaceSize() const { return _workspace_size; } \
\
static infiniStatus_t create( \
infiniopHandle_t handle, \
Descriptor **desc_ptr, \
infiniopTensorDescriptor_t y_desc, \
infiniopTensorDescriptor_t x_desc); \
\
infiniStatus_t calculate( \
void *workspace, size_t workspace_size, \
void *y, \
const void *x, \
void *stream) const; \
}; \
}
#endif // SILU_AND_MUL_H
src/infiniop/ops/softmax/operator.cc
View file @
8d09630a
...
...
@@ -2,7 +2,7 @@
#include "../../handle.h"
#include "infiniop/ops/softmax.h"
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_HYGON_API) || defined(ENABLE_ALI_API)
#include "nvidia/softmax_nvidia.cuh"
#endif
...
...
@@ -33,6 +33,9 @@ __C infiniStatus_t infiniopCreateSoftmaxDescriptor(
#endif
#ifdef ENABLE_HYGON_API
        CREATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_ALI_API
        CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
...
@@ -57,6 +60,9 @@ __C infiniStatus_t infiniopGetSoftmaxWorkspaceSize(infiniopSoftmaxDescriptor_t d
#endif
#ifdef ENABLE_HYGON_API
        GET(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_ALI_API
        GET(INFINI_DEVICE_ALI, nvidia);
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
...
@@ -86,6 +92,9 @@ __C infiniStatus_t infiniopSoftmax(
#endif
#ifdef ENABLE_HYGON_API
        CALCULATE(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_ALI_API
        CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
...
@@ -110,6 +119,9 @@ __C infiniStatus_t infiniopDestroySoftmaxDescriptor(infiniopSoftmaxDescriptor_t
#endif
#ifdef ENABLE_HYGON_API
        DESTROY(INFINI_DEVICE_HYGON, nvidia);
#endif
#ifdef ENABLE_ALI_API
        DESTROY(INFINI_DEVICE_ALI, nvidia);
#endif
    }
    return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
...
src/infiniop/ops/softplus/operator.cc
View file @
8d09630a
...
...
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/softplus_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/softplus_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
...
...
@@ -49,6 +49,10 @@ __C infiniStatus_t infiniopCreateSoftplusDescriptor(
#ifdef ENABLE_KUNLUN_API
        CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
        CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
...
...
@@ -82,6 +86,10 @@ __C infiniStatus_t infiniopGetSoftplusWorkspaceSize(infiniopSoftplusDescriptor_t
#ifdef ENABLE_KUNLUN_API
        GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
        GET(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
...
...
@@ -123,6 +131,10 @@ __C infiniStatus_t infiniopSoftplus(
#ifdef ENABLE_KUNLUN_API
        CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
        CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
...
...
@@ -158,6 +170,10 @@ infiniopDestroySoftplusDescriptor(infiniopSoftplusDescriptor_t desc) {
#ifdef ENABLE_KUNLUN_API
        DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
        DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
    }
...
...
src/infiniop/ops/sub/operator.cc
View file @
8d09630a
...
...
@@ -5,7 +5,7 @@
#ifdef ENABLE_CPU_API
#include "cpu/sub_cpu.h"
#endif
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API)
#if defined(ENABLE_NVIDIA_API) || defined(ENABLE_ILUVATAR_API) || defined(ENABLE_QY_API) || defined(ENABLE_ALI_API)
#include "nvidia/sub_nvidia.cuh"
#endif
#ifdef ENABLE_METAX_API
...
...
@@ -51,6 +51,9 @@ __C infiniStatus_t infiniopCreateSubDescriptor(
#ifdef ENABLE_KUNLUN_API
        CREATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
        CREATE(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
...
@@ -85,6 +88,9 @@ __C infiniStatus_t infiniopGetSubWorkspaceSize(infiniopSubDescriptor_t desc, siz
#ifdef ENABLE_KUNLUN_API
        GET(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
        GET(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
...
@@ -128,6 +134,9 @@ __C infiniStatus_t infiniopSub(
#ifdef ENABLE_KUNLUN_API
        CALCULATE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
        CALCULATE(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
...
@@ -164,6 +173,9 @@ infiniopDestroySubDescriptor(infiniopSubDescriptor_t desc) {
#ifdef ENABLE_KUNLUN_API
        DELETE(INFINI_DEVICE_KUNLUN, kunlun);
#endif
#ifdef ENABLE_ALI_API
        DELETE(INFINI_DEVICE_ALI, nvidia);
#endif
    default:
        return INFINI_STATUS_DEVICE_TYPE_NOT_SUPPORTED;
...
...