change / sglang / Commits / 3ee62235

Commit 3ee62235 (unverified), authored Jan 31, 2025 by Yineng Zhang, committed via GitHub on Jan 31, 2025
Parent: 9829e77e

    revert the MoE dependence (#3230)

Changes: 94 | Showing 20 changed files with 0 additions and 5074 deletions (+0, -5074)
Changed files (additions, deletions):

  sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp  +0 -105
  sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h  +0 -48
  sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h  +0 -87
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp  +0 -352
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h  +0 -120
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h  +0 -88
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp  +0 -550
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h  +0 -105
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h  +0 -352
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h  +0 -282
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h  +0 -141
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl  +0 -221
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp  +0 -58
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp  +0 -59
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp  +0 -642
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp  +0 -665
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h  +0 -438
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h  +0 -542
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h  +0 -162
  sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h  +0 -57
sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.cpp (deleted, file mode 100644 → 0, view file @ 9829e77e)
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tensorrt_llm/common/tllmException.h"
#include "tensorrt_llm/common/stringUtils.h"
#include <cstdlib>
#if !defined(_MSC_VER)
#include <cxxabi.h>
#include <dlfcn.h>
#include <execinfo.h>
#endif
#include <sstream>
namespace tensorrt_llm::common
{

namespace
{
int constexpr VOID_PTR_SZ = 2 + sizeof(void*) * 2;
}

#if !defined(_MSC_VER)

TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
    : std::runtime_error{""}
{
    mNbFrames = backtrace(mCallstack.data(), MAX_FRAMES);
    auto const trace = getTrace();
    std::runtime_error::operator=(
        std::runtime_error{fmtstr("%s (%s:%zu)\n%s", msg.c_str(), file, line, trace.c_str())});
}
#else
TllmException::TllmException(char const* file, std::size_t line, std::string const& msg)
    : mNbFrames{}
    , std::runtime_error{fmtstr("%s (%s:%zu)", msg.c_str(), file, line)}
{
}
#endif

TllmException::~TllmException() noexcept = default;

std::string TllmException::getTrace() const
{
#if defined(_MSC_VER)
    return "";
#else
    auto const trace = backtrace_symbols(mCallstack.data(), mNbFrames);
    std::ostringstream buf;
    for (auto i = 1; i < mNbFrames; ++i)
    {
        Dl_info info;
        if (dladdr(mCallstack[i], &info) && info.dli_sname)
        {
            auto const clearName = demangle(info.dli_sname);
            buf << fmtstr("%-3d %*p %s + %zd", i, VOID_PTR_SZ, mCallstack[i], clearName.c_str(),
                static_cast<char*>(mCallstack[i]) - static_cast<char*>(info.dli_saddr));
        }
        else
        {
            buf << fmtstr("%-3d %*p %s", i, VOID_PTR_SZ, mCallstack[i], trace[i]);
        }
        if (i < mNbFrames - 1)
            buf << std::endl;
    }

    if (mNbFrames == MAX_FRAMES)
        buf << std::endl << "[truncated]";

    std::free(trace);
    return buf.str();
#endif
}

std::string TllmException::demangle(char const* name)
{
#if defined(_MSC_VER)
    return name;
#else
    std::string clearName{name};
    auto status = -1;
    auto const demangled = abi::__cxa_demangle(name, nullptr, nullptr, &status);
    if (status == 0)
    {
        clearName = demangled;
        std::free(demangled);
    }
    return clearName;
#endif
}

} // namespace tensorrt_llm::common
sgl-kernel/3rdparty/tensorrt_llm/common/tllmException.h (deleted, file mode 100644 → 0, view file @ 9829e77e)
/*
* Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <array>
#include <cstddef>
#include <stdexcept>
#include <string>
#define NEW_TLLM_EXCEPTION(...) \
tensorrt_llm::common::TllmException(__FILE__, __LINE__, tensorrt_llm::common::fmtstr(__VA_ARGS__))
namespace tensorrt_llm::common
{

class TllmException : public std::runtime_error
{
public:
    static auto constexpr MAX_FRAMES = 128;

    explicit TllmException(char const* file, std::size_t line, std::string const& msg);

    ~TllmException() noexcept override;

    [[nodiscard]] std::string getTrace() const;

    static std::string demangle(char const* name);

private:
    std::array<void*, MAX_FRAMES> mCallstack{};
    int mNbFrames;
};

} // namespace tensorrt_llm::common
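For context on what this header provided: call sites typically threw through the NEW_TLLM_EXCEPTION macro defined above, which stamps the current file and line (and, on non-MSVC builds, a demangled backtrace) into the message. A minimal illustrative sketch; the check_positive helper and its message are hypothetical and not part of this diff:

#include "tensorrt_llm/common/tllmException.h"

// Hypothetical caller: throws a TllmException carrying file, line, and a backtrace.
void check_positive(int value)
{
    if (value <= 0)
    {
        throw NEW_TLLM_EXCEPTION("expected a positive value, got %d", value);
    }
}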
sgl-kernel/3rdparty/tensorrt_llm/common/workspace.h (deleted, file mode 100644 → 0, view file @ 9829e77e)
/*
* Copyright (c) 1993-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cstddef>
#include <cstdint>
namespace tensorrt_llm::common
{

std::uintptr_t constexpr kCudaMemAlign = 128;

inline int8_t* alignPtr(int8_t* ptr, uintptr_t to)
{
    uintptr_t addr = (uintptr_t) ptr;
    if (addr % to)
    {
        addr += to - addr % to;
    }
    return (int8_t*) addr;
}

constexpr size_t alignSize(size_t size, size_t to)
{
    if ((size % to) != 0U)
    {
        size += to - size % to;
    }
    return size;
}

inline int8_t* nextWorkspacePtrCommon(int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment)
{
    uintptr_t addr = (uintptr_t) ptr;
    addr += previousWorkspaceSize;
    return alignPtr((int8_t*) addr, alignment);
}

inline int8_t* nextWorkspacePtr(int8_t* ptr, uintptr_t previousWorkspaceSize)
{
    return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, kCudaMemAlign);
}

inline int8_t* nextWorkspacePtr(
    int8_t* const base, uintptr_t& offset, uintptr_t const size, uintptr_t const alignment = kCudaMemAlign)
{
    uintptr_t curr_offset = offset;
    uintptr_t next_offset = curr_offset + ((size + alignment - 1) / alignment) * alignment;
    int8_t* newptr = size == 0 ? nullptr : base + curr_offset;
    offset = next_offset;
    return newptr;
}

inline int8_t* nextWorkspacePtrWithAlignment(
    int8_t* ptr, uintptr_t previousWorkspaceSize, uintptr_t const alignment = kCudaMemAlign)
{
    return nextWorkspacePtrCommon(ptr, previousWorkspaceSize, alignment);
}

inline size_t calculateTotalWorkspaceSize(size_t const* workspaces, int count, uintptr_t const alignment = kCudaMemAlign)
{
    size_t total = 0;
    for (int i = 0; i < count; i++)
    {
        total += workspaces[i];
        if (workspaces[i] % alignment)
        {
            total += alignment - (workspaces[i] % alignment);
        }
    }
    return total;
}

}; // namespace tensorrt_llm::common
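A minimal sketch of how these helpers are typically combined to carve a single device allocation into aligned sub-buffers; the two buffer sizes below are hypothetical, not taken from the diff:

#include <cstddef>
#include <cstdint>
// assumes the workspace.h helpers shown above

inline void carve_workspace(int8_t* workspace)
{
    using namespace tensorrt_llm::common;

    // Two hypothetical sub-buffers of 1000 and 4096 bytes.
    size_t sizes[] = {1000, 4096};
    size_t total = calculateTotalWorkspaceSize(sizes, 2); // each entry padded up to kCudaMemAlign (128)

    uintptr_t offset = 0;
    int8_t* buf0 = nextWorkspacePtr(workspace, offset, sizes[0]); // workspace + 0
    int8_t* buf1 = nextWorkspacePtr(workspace, offset, sizes[1]); // workspace + 1024
    // offset now equals total (5120); both sub-buffers are 128-byte aligned relative to workspace.
    (void) total;
    (void) buf0;
    (void) buf1;
}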
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/copy_red_global.hpp (deleted, file mode 100644 → 0, view file @ 9829e77e)
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include <cute/config.hpp>
#include <cute/arch/util.hpp>
#include <cute/atom/copy_traits.hpp>
#include <cute/numeric/numeric_types.hpp>
// Config
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) && (__CUDACC_VER_MAJOR__ >= 10))
#define CUTE_ARCH_RED_F16_SM70_ENABLED
#endif
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12))
#define CUTE_ARCH_RED_VEC_SM90_ENABLED
#define CUTE_ARCH_RED_BF16_SM90_ENABLED
#endif
namespace cute
{

//////////////////////////////////
// Wrapper around CUDA's atomicAdd
//////////////////////////////////

template <class T>
struct TypedAtomicAdd
{
    using SRegisters = T[1];
    using DRegisters = T[1];

    CUTE_HOST_DEVICE static constexpr void copy(T const& src, T& dst)
    {
        atomicAdd(&dst, src);
    }
};

template <class T>
struct Copy_Traits<TypedAtomicAdd<T>>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;
    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, Int<sizeof_bits<T>::value>>>;
    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

//////////////////////////////////
// F16 ADD PTX
//////////////////////////////////

struct SM70_RED_ADD_NOFTZ_F16
{
    using SRegisters = uint16_t[1];
    using DRegisters = uint16_t[1];

    CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
        asm volatile("red.global.add.noftz.f16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;
    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _16>>;
    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _16>>;
    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

struct SM70_RED_ADD_NOFTZ_F16x2
{
    using SRegisters = uint32_t[1];
    using DRegisters = uint32_t[1];

    CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_F16_SM70_ENABLED)
        asm volatile("red.global.add.noftz.f16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.f16 without CUTE_ARCH_RED_F16_SM70_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM70_RED_ADD_NOFTZ_F16x2>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;
    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _32>>;
    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _32>>;
    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

struct SM90_RED_ADD_NOFTZ_F16x2_V2
{
    using SRegisters = uint32_t[2];
    using DRegisters = uint64_t[1];

    CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
        asm volatile("red.global.add.noftz.v2.f16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V2>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;
    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _64>>;
    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _64>>;
    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

struct SM90_RED_ADD_NOFTZ_F16x2_V4
{
    using SRegisters = uint32_t[4];
    using DRegisters = uint128_t[1];

    CUTE_HOST_DEVICE static void copy(
        uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_VEC_SM90_ENABLED)
        asm volatile("red.global.add.noftz.v4.f16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
            "r"(src2), "r"(src3));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.vX without CUTE_ARCH_RED_VEC_SM90_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_F16x2_V4>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;
    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _128>>;
    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _128>>;
    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

//////////////////////////////////
// BF16 ADD PTX
//////////////////////////////////

struct SM90_RED_ADD_NOFTZ_BF16
{
    using SRegisters = uint16_t[1];
    using DRegisters = uint16_t[1];

    CUTE_HOST_DEVICE static void copy(uint16_t const& src0, uint16_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
        asm volatile("red.global.add.noftz.bf16 [%0], %1;\n" ::"l"(&gmem_dst), "h"(src0));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;
    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _16>>;
    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _16>>;
    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

//////////////////////////////////

struct SM90_RED_ADD_NOFTZ_BF16x2
{
    using SRegisters = uint32_t[1];
    using DRegisters = uint32_t[1];

    CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
        asm volatile("red.global.add.noftz.bf16x2 [%0], %1;\n" ::"l"(&gmem_dst), "r"(src0));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;
    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _32>>;
    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _32>>;
    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

//////////////////////////////////

struct SM90_RED_ADD_NOFTZ_BF16x2_V2
{
    using SRegisters = uint32_t[2];
    using DRegisters = uint64_t[1];

    CUTE_HOST_DEVICE static void copy(uint32_t const& src0, uint32_t const& src1, uint64_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
        asm volatile("red.global.add.noftz.v2.bf16x2 [%0], {%1, %2};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V2>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;
    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _64>>;
    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _64>>;
    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

//////////////////////////////////

struct SM90_RED_ADD_NOFTZ_BF16x2_V4
{
    using SRegisters = uint32_t[4];
    using DRegisters = uint128_t[1];

    CUTE_HOST_DEVICE static void copy(
        uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, uint128_t& gmem_dst)
    {
#if defined(CUTE_ARCH_RED_BF16_SM90_ENABLED)
        asm volatile("red.global.add.noftz.v4.bf16x2 [%0], {%1, %2, %3, %4};\n" ::"l"(&gmem_dst), "r"(src0), "r"(src1),
            "r"(src2), "r"(src3));
#else
        CUTE_INVALID_CONTROL_PATH("Trying to use red.global.bf16 without CUTE_ARCH_RED_BF16_SM90_ENABLED.");
#endif
    }
};

template <>
struct Copy_Traits<SM90_RED_ADD_NOFTZ_BF16x2_V4>
{
    // Logical thread id to thread idx (one-thread)
    using ThrID = Layout<_1>;
    // Map from (src-thr,src-val) to bit
    using SrcLayout = Layout<Shape<_1, _128>>;
    // Map from (dst-thr,dst-val) to bit
    using DstLayout = Layout<Shape<_1, _128>>;
    // Reference map from (thr,val) to bit
    using RefLayout = SrcLayout;
};

//////////////////////////////////

} // end namespace cute
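For readers unfamiliar with the red.global PTX wrapped above: it performs a fire-and-forget atomic add into global memory, which is what lets the fused MoE epilogue (see epilogue_moe_finalize.hpp below) scatter partial results directly into the final output tensor. A rough CUDA equivalent of the f16x2 variant, using the standard __half2 atomicAdd intrinsic (SM 6.0+); this kernel is only an illustration and is not part of the deleted file:

#include <cuda_fp16.h>

// Each thread accumulates a packed pair of halves into global memory, roughly what
// SM70_RED_ADD_NOFTZ_F16x2 emits; the PTX form additionally discards the old value
// and never flushes denormals to zero.
__global__ void scatter_add_half2(__half2* dst, __half2 const* src, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
    {
        atomicAdd(&dst[i], src[i]);
    }
}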
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/arch/mma.h (deleted, file mode 100644 → 0, view file @ 9829e77e)
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Templates exposing architecture support for multiply-add operations
*/
#pragma once
#include "cutlass_extensions/weight_only_quant_op.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace arch
{

// Tag which triggers MMA which will trigger
struct OpMultiplyAddDequantizeInterleavedBToA;

/*
    Below we have extra tags to signal what kind of dequantization we want to do
    (per col, scale only fine grained, finegrained with zero). This still lets us use
    the existing template infrastructure (incl. that in CUTLASS). However, we
    split out the template below into OpMultiplyAddDequantizeInterleavedBToA along
    with the quantization op before instantiating the GEMM pieces.

    Note that this is somewhat of a hack, but it SIGNIFICANTLY reduces the amount of
    code we need to duplicate.
 */
struct OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
struct OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;

// The default just forwards the original operator
template <typename MmaOp, WeightOnlyQuantOp QuantOp_>
struct TagOperator
{
    using TaggedOperator = MmaOp;
};

// Specializations below attach more information to the operator
template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY>
{
    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_percol_scale;
};

template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>
{
    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scale;
};

template <>
struct TagOperator<OpMultiplyAddDequantizeInterleavedBToA, WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS>
{
    using TaggedOperator = OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias;
};

// Here we instantiate some structs to "detag" the tagged operator. It splits it back to the original
// operator + the extra information. If no extra info was tagged, the dequant op per column scaling
// as a default.
template <typename TaggedMmaOp>
struct DetagOperator
{
    using Operator = TaggedMmaOp;
    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
};

template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_percol_scale>
{
    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::PER_COLUMN_SCALE_ONLY;
};

template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scale>
{
    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY;
};

template <>
struct DetagOperator<OpMultiplyAddDequantizeInterleavedBToA_fine_scalebias>
{
    using Operator = OpMultiplyAddDequantizeInterleavedBToA;
    static constexpr WeightOnlyQuantOp QuantOp = WeightOnlyQuantOp::FINEGRAINED_SCALE_AND_ZEROS;
};

} // namespace arch
} // namespace cutlass
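A small compile-time sketch of how the tag/detag pair round-trips an operator together with its quantization mode. This assumes WeightOnlyQuantOp is reachable as cutlass::WeightOnlyQuantOp (it comes from the weight_only_quant_op.h include above); the snippet is purely illustrative:

#include <type_traits>

using cutlass::arch::DetagOperator;
using cutlass::arch::OpMultiplyAddDequantizeInterleavedBToA;
using cutlass::arch::TagOperator;

// Tag the base operator with a fine-grained, scale-only quantization mode...
using Tagged = TagOperator<OpMultiplyAddDequantizeInterleavedBToA,
    cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>::TaggedOperator;
// ...then detag it to recover both pieces of information.
using Detagged = DetagOperator<Tagged>;

static_assert(std::is_same_v<Detagged::Operator, OpMultiplyAddDequantizeInterleavedBToA>,
    "detagging recovers the original operator");
static_assert(Detagged::QuantOp == cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY,
    "detagging recovers the quantization mode");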
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/compute_occupancy.h (deleted, file mode 100644 → 0, view file @ 9829e77e)
/*
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include <cuda_runtime_api.h>
#include "cutlass/device_kernel.h"
#include "tensorrt_llm/common/cudaUtils.h"
namespace tensorrt_llm
{
namespace cutlass_extensions
{

template <typename GemmKernel, bool enable_cutlass_3x = false>
inline int compute_occupancy_for_kernel()
{
    int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

    if (smem_size > (48 << 10))
    {
        cudaFuncAttributes attr;
        int device = 0;
        int max_smem_per_block = 0;
        tensorrt_llm::common::check_cuda_error(cudaGetDevice(&device));
        tensorrt_llm::common::check_cuda_error(
            cudaDeviceGetAttribute(&max_smem_per_block, cudaDevAttrMaxSharedMemoryPerBlockOptin, device));
        if constexpr (enable_cutlass_3x)
        {
            tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::device_kernel<GemmKernel>));
        }
        else
        {
            tensorrt_llm::common::check_cuda_error(cudaFuncGetAttributes(&attr, cutlass::Kernel<GemmKernel>));
        }
        if (smem_size + attr.sharedSizeBytes >= static_cast<size_t>(max_smem_per_block))
        {
            // This should mean that
            // cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)
            // wouldn't work. In that case, we return an occupancy of 0. This will cause the heuristic to ignore this
            // configuration.
            return 0;
        }
        if constexpr (enable_cutlass_3x)
        {
            tensorrt_llm::common::check_cuda_error(cudaFuncSetAttribute(
                cutlass::device_kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
        else
        {
            tensorrt_llm::common::check_cuda_error(
                cudaFuncSetAttribute(cutlass::Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
        }
    }

    int max_active_blocks = -1;
    if constexpr (enable_cutlass_3x)
    {
        tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(&max_active_blocks,
            cutlass::device_kernel<GemmKernel>, 128 * (GemmKernel::NumLoadWarpGroups + GemmKernel::NumMmaWarpGroups),
            smem_size));
    }
    else
    {
        tensorrt_llm::common::check_cuda_error(cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &max_active_blocks, cutlass::Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size));
    }

    return max_active_blocks;
}

} // namespace cutlass_extensions
} // namespace tensorrt_llm
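Illustratively, the GEMM heuristics that consumed this helper call it once per candidate tile configuration and discard configurations that report zero occupancy. A hedged sketch (config_is_usable is a hypothetical wrapper, not a TensorRT-LLM API):

#include "cutlass_extensions/compute_occupancy.h"

// Hypothetical filter: keep only kernel configurations that can be resident on an SM at all.
template <typename GemmKernel>
bool config_is_usable()
{
    int const occupancy = tensorrt_llm::cutlass_extensions::compute_occupancy_for_kernel<GemmKernel>();
    return occupancy > 0; // 0 means the kernel's shared-memory demand exceeds the opt-in per-block limit
}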
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/collective/epilogue_moe_finalize.hpp (deleted, file mode 100644 → 0, view file @ 9829e77e)
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing elementwise operations used by epilogues.
*/
#pragma once
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/collective/detail.hpp"
#include "cutlass/fast_math.h"
#include "cute/numeric/numeric_types.hpp"
#include "cute/tensor.hpp"
#include "cutlass/trace.h"
#include "cutlass_extensions/arch/copy_red_global.hpp"
#include "cutlass_extensions/util/gather_tensor.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace epilogue
{
namespace collective
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <class StrideC_, class ElementD_, class StrideD_, class ThreadEpilogueOp_, class ElementBias,
    class StrideBias, class ElementScale, class StrideScale, class EpilogueTile, class SmemLayoutAtomD,
    class CopyOpR2S, class CopyOpS2R, class CopyOpR2G>
class EpilogueMoeFusedFinalize
{
public:
    using EpilogueSchedule = PtrArrayNoSmemWarpSpecialized;
    using DispatchPolicy = PtrArrayNoSmemWarpSpecialized;

    using ThreadEpilogueOp = ThreadEpilogueOp_;
    using ElementOutput = typename ThreadEpilogueOp::ElementOutput;
    using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator;
    using ElementCompute = typename ThreadEpilogueOp::ElementCompute;
    using ElementIntermediate = typename ThreadEpilogueOp::ElementD;

    using ElementC = typename ThreadEpilogueOp::ElementC;
    using StrideC = StrideC_;
    using InternalStrideC = cute::remove_pointer_t<StrideC>;
    using ElementD = ElementD_;
    using StrideD = StrideD_;
    using InternalStrideD = cute::remove_pointer_t<StrideD>;

    static_assert(!is_same_v<InternalStrideC, StrideC>, "Stride C must be a pointer");
    static_assert(is_same_v<InternalStrideD, StrideD>, "Stride D must not be a pointer");

    using CopyAtomR2S = Copy_Atom<CopyOpR2S, ElementAccumulator>;
    using CopyAtomS2R = Copy_Atom<CopyOpS2R, ElementAccumulator>;
    using CopyAtomR2G = Copy_Atom<CopyOpR2G, ElementD>;
    static constexpr int AlignmentD = CopyAtomR2G::NumValSrc;

    using SmemLayoutD = decltype(tile_to_shape(SmemLayoutAtomD{}, EpilogueTile{}));

    constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{});

    struct SharedStorage
    {
        alignas(SmemAlignmentD) cute::ArrayEngine<ElementAccumulator, cosize_v<SmemLayoutD>> smem_D;
    };

    struct TensorMapStorage
    {
    };

    struct Arguments
    {
        typename ThreadEpilogueOp::Params thread{};
        ElementC const** ptr_C{};
        StrideC dC{};
        ElementD* ptr_D{};
        StrideD dD{};
        ElementBias const* ptr_bias;
        StrideBias dBias{};
        ElementScale const* ptr_scale;
        StrideScale dScale{};
        int64_t const* group_offset{};
        int32_t const* scatter_index{};
        cutlass::FastDivmod num_rows_in_final_output;
    };

    using Params = Arguments;

    //
    // Methods
    //

    template <class ProblemShape>
    static constexpr Params to_underlying_arguments(
        ProblemShape const&, Arguments const& args, [[maybe_unused]] void* workspace)
    {
        return args;
    }

    template <class ProblemShape>
    static size_t get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count = 0)
    {
        return 0;
    }

    template <class ProblemShape>
    static cutlass::Status initialize_workspace(ProblemShape const& problem_shape, Arguments const& args,
        void* workspace, cudaStream_t stream, CudaHostAdapter* cuda_adapter = nullptr)
    {
        return cutlass::Status::kSuccess;
    }

    template <class ProblemShape>
    CUTLASS_HOST_DEVICE static bool can_implement(
        [[maybe_unused]] ProblemShape problem_shape, [[maybe_unused]] Arguments const& args)
    {
        bool implementable = true;
        if (problem_shape.is_host_problem_shape_available())
        {
            // Check alignment for all problem sizes
            for (int i = 0; i < problem_shape.groups(); i++)
            {
                auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1);
                auto [M, N, K, L] = problem_shape_MNKL;
                implementable = implementable
                    && cutlass::detail::check_alignment<AlignmentD>(cute::make_shape(M, N, L), InternalStrideD{});
            }
        }

        if (!implementable)
        {
            CUTLASS_TRACE_HOST(
                "  CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for selected global "
                "reduction instruction.\n");
        }
        return implementable;
    }

    CUTLASS_HOST_DEVICE
    EpilogueMoeFusedFinalize(Params const& params_)
        : params(params_)
    {
    }

    CUTLASS_DEVICE
    bool is_source_needed()
    {
        // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta.
        return params.ptr_C != nullptr
            && (params.thread.beta_ptr_array || params.thread.beta_ptr || params.thread.beta != 0);
    }

    template <class ProblemShapeMNKL, class BlockShapeMNK, class BlockCoordMNKL, class FrgEngine, class FrgLayout,
        class TiledMma, class ResidueMNK>
    CUTLASS_HOST_DEVICE void operator()(ProblemShapeMNKL problem_shape_mnkl, BlockShapeMNK blk_shape_MNK,
        BlockCoordMNKL blk_coord_mnkl, cute::Tensor<FrgEngine, FrgLayout> const& accumulators, TiledMma tiled_mma,
        ResidueMNK residue_mnk, int thread_idx, [[maybe_unused]] char* smem_buf)
    {
        using namespace cute;
        using X = Underscore;

        static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4");
        static_assert(is_static<BlockShapeMNK>::value, "ThreadBlock tile shape must be static");
        static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3");
        static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3");

        auto synchronize = [&]()
        { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); };

        // Separate out problem shape for convenience
        auto M = get<0>(problem_shape_mnkl);
        auto N = get<1>(problem_shape_mnkl);
        auto L = get<3>(problem_shape_mnkl);

        auto mma_tile_m = tile_size<0>(tiled_mma);
        auto mma_tile_n = tile_size<1>(tiled_mma);
        auto epi_tile_m = size<0>(EpilogueTile{});
        auto epi_tile_n = size<1>(EpilogueTile{});

        CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M");
        CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N");

        // Batches are managed by using appropriate pointers to C and D matrices
        int32_t const mock_L = 1;
        int32_t const mock_l_coord = 0;

        // Slice to get the tile this CTA is responsible for
        auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl;

        // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups.
        // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups,
        // we get the correct alpha/beta values for the current batch/group using group index.
        ThreadEpilogueOp epilogue_op(params.thread, l_coord);

        SharedStorage& storage = *reinterpret_cast<SharedStorage*>(smem_buf);

        Tensor sD_ = make_tensor(make_smem_ptr(storage.smem_D.begin()), SmemLayoutD{});
        Tensor sD = as_position_independent_swizzle_tensor(sD_);

        // Function to scatter output rows
        auto& num_rows = params.num_rows_in_final_output;
        auto read_scatter_map = IndexedGather(make_gmem_ptr(params.scatter_index + params.group_offset[l_coord]));
        auto get_scatter_idx = [&](auto i)
        {
            auto scatter = read_scatter_map(i);
            int quot, rem;
            num_rows(quot, rem, scatter);
            return rem;
        };

        // Represent the full output tensor
        ElementC const* ptr_C = epilogue_op.is_source_needed() ? params.ptr_C[l_coord] : nullptr;
        auto dC = epilogue_op.is_source_needed() ? params.dC[l_coord] : InternalStrideC{};
        Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C), make_shape(M, N, mock_L), dC);                   // (m,n,l)
        Tensor mD_mnl = make_gather_tensor(
            make_gmem_ptr(params.ptr_D), make_shape(M, N, mock_L), params.dD, get_scatter_idx);            // (m,n,l)

        // Use fake shape for bias, it doesn't matter
        bool const is_bias_needed = params.ptr_bias != nullptr;
        Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_bias), make_shape(M, N, 1), params.dBias);
        Tensor mScale_mnl = make_tensor(
            make_gmem_ptr(params.ptr_scale + params.group_offset[l_coord]), make_shape(M, N), params.dScale);

        Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)
        Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)

        Tensor gC = gC_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)
        Tensor gD = gD_mnl(_, _, m_coord, n_coord, mock_l_coord); // (BLK_M,BLK_N)

        Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
        Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)

        Tensor gBias_mnl = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{});   // (BLK_M,BLK_N,m,n,l)
        Tensor gScale_mnl = local_tile(mScale_mnl, blk_shape_MNK, make_coord(_, _, _), Step<_1, _1, X>{}); // (BLK_M,BLK_N,m,n,l)

        Tensor gBias = gBias_mnl(_, _, m_coord, n_coord, l_coord); // (BLK_M,BLK_N)
        Tensor gScale = gScale_mnl(_, _, m_coord, n_coord);        // (BLK_M,BLK_N)

        Tensor gBias_epi = flat_divide(gBias, EpilogueTile{});   // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
        Tensor gScale_epi = flat_divide(gScale, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)

        // Get the smallest tiled copy we can use to retile the accumulators
        TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(Copy_Atom<SM90_U32x4_STSM_N, cutlass::half_t>{}, tiled_mma);
        TiledCopy tiled_r2s = make_tiled_copy_S(CopyAtomR2S{}, tiled_copy_C_atom);

        auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx);
        Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N)
        Tensor tRS_sD = thread_r2s.partition_D(sD);          // ((R2S,R2S_V),R2S_M,R2S_N)

        Tensor tRS_rD = make_tensor<ElementAccumulator>(shape(tRS_sD)); // ((R2S,R2S_V),R2S_M,R2S_N)

        // Make a tiled copy vectorized along major direction of D
        auto tiled_s2r = [&]()
        {
            if constexpr (cutlass::gemm::detail::is_k_major<StrideD>())
            {
                constexpr int NumThreadsMajor = epi_tile_n / AlignmentD;
                constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
                return make_tiled_copy(CopyAtomS2R{},
                    Layout<Shape<Int<NumThreadsMinor>, Int<NumThreadsMajor>>, Stride<Int<NumThreadsMajor>, _1>>{},
                    Layout<Shape<_1, Int<AlignmentD>>>{});
            }
            else if constexpr (cutlass::gemm::detail::is_mn_major<StrideD>())
            {
                constexpr int NumThreadsMajor = epi_tile_m / AlignmentD;
                constexpr int NumThreadsMinor = cute::size(tiled_mma) / NumThreadsMajor;
                return make_tiled_copy(CopyAtomS2R{},
                    Layout<Shape<Int<NumThreadsMajor>, Int<NumThreadsMinor>>, Stride<_1, Int<NumThreadsMajor>>>{},
                    Layout<Shape<Int<AlignmentD>, _1>>{});
            }
            else
            {
                static_assert(cute::is_void_v<StrideD>, "Unsupported D gmem layout.");
            }
        }();

        auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx);
        Tensor tSR_sD = thread_s2r.partition_S(sD);              // ((S2R,S2R_V),S2R_M,S2R_N)
        Tensor tSR_gD = thread_s2r.partition_D(gD_epi);          // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
        Tensor tSR_gC = thread_s2r.partition_D(gC_epi);          // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
        Tensor tSR_gBias = thread_s2r.partition_D(gBias_epi);    // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)
        Tensor tSR_gScale = thread_s2r.partition_D(gScale_epi);  // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)

        // Allocate intermediate registers for a single subtile
        Tensor tSR_rD = make_tensor<ElementAccumulator>(take<0, 3>(shape(tSR_gD)));        // ((S2R,S2R_V),S2R_M,S2R_N)
        Tensor tSR_rD_final = make_tensor<ElementD>(shape(tSR_rD));                        // ((S2R,S2R_V),S2R_M,S2R_N)
        Tensor tSR_rC = make_tensor<ElementC>(shape(tSR_rD));                              // ((S2R,S2R_V),S2R_M,S2R_N)
        Tensor tSR_rBias = make_tensor<ElementBias>(tSR_gBias(_, _, _, 0, 0).layout());    // ((S2R,S2R_V),S2R_M,S2R_N)
        Tensor tSR_rScale = make_tensor<ElementScale>(tSR_gScale(_, _, _, 0, 0).layout()); // ((S2R,S2R_V),S2R_M,S2R_N)

        // Make an identity coordinate tensor for predicating our output MN tile
        Tensor cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD))));
        Tensor cD_epi = flat_divide(cD, EpilogueTile{});  // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N)
        Tensor tSR_cD = thread_s2r.partition_D(cD_epi);   // ((S2R,S2R_V),S2R_M,S2R_N,EPI_M,EPI_N)

        // epilogue subtile loop
        CUTLASS_PRAGMA_UNROLL
        for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m)
        {
            CUTLASS_PRAGMA_UNROLL
            for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n)
            {
                int mma_m = (epi_m * epi_tile_m) / mma_tile_m;
                int mma_n = (epi_n * epi_tile_n) / mma_tile_n;
                Tensor tRS_rAcc_mn = tRS_rAcc(_, mma_m, mma_n);

                int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n);
                int r2s_v = epi_n_in_mma * size(tRS_rD);
                CUTLASS_PRAGMA_UNROLL
                for (int epi_v = 0; epi_v < size(tRS_rD); ++epi_v)
                {
                    tRS_rD(epi_v) = tRS_rAcc_mn(r2s_v + epi_v);
                }

                copy(tiled_r2s, tRS_rD, tRS_sD);
                synchronize();

                copy(tiled_s2r, tSR_sD, tSR_rD);
                synchronize();

                Tensor tSR_gC_mn = tSR_gC(_, _, _, epi_m, epi_n);
                Tensor tSR_gBias_mn = tSR_gBias(_, _, _, epi_m, epi_n);
                Tensor tSR_gScale_mn = tSR_gScale(_, _, _, epi_m, epi_n);
                Tensor tSR_cD_mn = tSR_cD(_, _, _, epi_m, epi_n);
                Tensor tSR_gD_mn = tSR_gD(_, _, _, epi_m, epi_n);

                if (epilogue_op.is_source_needed())
                {
                    CUTLASS_PRAGMA_UNROLL
                    for (int m = 0; m < size<1>(tSR_rD); ++m)
                    {
                        CUTLASS_PRAGMA_UNROLL
                        for (int n = 0; n < size<2>(tSR_rD); ++n)
                        {
                            if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
                            {
                                copy(tSR_gC_mn(_, m, n), tSR_rC(_, m, n));
                                if (is_bias_needed)
                                {
                                    copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
                                }
                                copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));

                                CUTLASS_PRAGMA_UNROLL
                                for (int i = 0; i < size<0>(tSR_rD); ++i)
                                {
                                    auto epi_value = epilogue_op(tSR_rD(i, m, n), tSR_rC(i, m, n));
                                    if (is_bias_needed)
                                    {
                                        epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
                                    }
                                    tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
                                }
                                copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
                            }
                        }
                    }
                }
                else
                {
                    CUTLASS_PRAGMA_UNROLL
                    for (int m = 0; m < size<1>(tSR_rD); ++m)
                    {
                        CUTLASS_PRAGMA_UNROLL
                        for (int n = 0; n < size<2>(tSR_rD); ++n)
                        {
                            if (elem_less(tSR_cD_mn(0, m, n), make_coord(get<0>(residue_mnk), get<1>(residue_mnk))))
                            {
                                if (is_bias_needed)
                                {
                                    copy(tSR_gBias_mn(_, m, n), tSR_rBias(_, m, n));
                                }
                                copy(tSR_gScale_mn(_, m, n), tSR_rScale(_, m, n));

                                CUTLASS_PRAGMA_UNROLL
                                for (int i = 0; i < size<0>(tSR_rD); ++i)
                                {
                                    auto epi_value = epilogue_op(tSR_rD(i, m, n));
                                    if (is_bias_needed)
                                    {
                                        epi_value += static_cast<ElementCompute>(tSR_rBias(i, m, n));
                                    }
                                    tSR_rD_final(i, m, n) = static_cast<ElementD>(tSR_rScale(i, m, n) * epi_value);
                                }
                                copy(CopyAtomR2G{}, tSR_rD_final(_, m, n), tSR_gD_mn(_, m, n));
                            }
                        }
                    }
                }
            }
        }
    }

private:
    Params params;
};

namespace detail
{

template <class Element, class MaxVec>
constexpr auto get_vectorized_atomic_add_op()
{
    using namespace cute;

    auto constexpr MaxVecSize = size(MaxVec{});

    if constexpr (is_same_v<Element, cutlass::half_t>)
    {
        if constexpr (MaxVecSize >= 8)
        {
            return SM90_RED_ADD_NOFTZ_F16x2_V4{};
        }
        else if constexpr (MaxVecSize >= 4)
        {
            return SM90_RED_ADD_NOFTZ_F16x2_V2{};
        }
        else if constexpr (MaxVecSize >= 2)
        {
            return SM70_RED_ADD_NOFTZ_F16x2{};
        }
        else
        {
            return SM70_RED_ADD_NOFTZ_F16{};
        }
    }
    else if constexpr (is_same_v<Element, cutlass::bfloat16_t>)
    {
        if constexpr (MaxVecSize >= 8)
        {
            return SM90_RED_ADD_NOFTZ_BF16x2_V4{};
        }
        else if constexpr (MaxVecSize >= 4)
        {
            return SM90_RED_ADD_NOFTZ_BF16x2_V2{};
        }
        else if constexpr (MaxVecSize >= 2)
        {
            return SM90_RED_ADD_NOFTZ_BF16x2{};
        }
        else
        {
            return SM90_RED_ADD_NOFTZ_BF16{};
        }
    }
    else
    {
        // non-vectorized atomic add for all other types until supported
        return TypedAtomicAdd<Element>{};
    }
}

} // namespace detail

template <class TileShape, class ElementC, class StrideC, class ElementD, class StrideD, class ElementAccumulator,
    class ElementCompute, class ElementBias, class StrideBias, class ElementScale, class StrideScale>
struct EpilogueMoeFusedFinalizeBuilder
{
    // assuming cooperative kernel schedule
    using EpiTileN = decltype(cute::min(size<1>(TileShape{}), _32{}));
    using EpilogueTile = Shape<_128, EpiTileN>;

    // Output of linear combination is ElementCompute instead of ElementD
    // since we will be doing more computate on it, no need to cast yet.
    using ThreadEpilogueOp
        = cutlass::epilogue::thread::LinearCombination<ElementCompute, 1, ElementAccumulator, ElementCompute,
            cutlass::epilogue::thread::ScaleType::Default, cutlass::FloatRoundStyle::round_to_nearest, ElementC>;

    using SmemLayoutAtomD
        = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom<StrideD, ElementAccumulator, EpilogueTile>());
    using CopyAtomR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator<StrideD, ElementAccumulator>());
    using CopyAtomS2R = DefaultCopy;
    using CopyAtomR2G = decltype(detail::get_vectorized_atomic_add_op<ElementD, EpiTileN>());

    template <class EpilogueOp>
    struct Sm90TmaWarpSpecializedAdapterWithSmemStorage : detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>
    {
        // We need to override this one using declaration because otherwise we double up on the smem
        using TensorMapStorage = typename EpilogueOp::TensorMapStorage;

        using Base = detail::Sm90TmaWarpSpecializedAdapter<EpilogueOp>;

        CUTLASS_HOST_DEVICE
        Sm90TmaWarpSpecializedAdapterWithSmemStorage(
            typename EpilogueOp::Params const& params, [[maybe_unused]] typename Base::TensorStorage& shared_tensors)
            : Base(params)
        {
        }

        // These functions depend on the type of TensorMapStorage
        template <bool IsLoad>
        CUTLASS_DEVICE void tensormaps_perform_update([[maybe_unused]] TensorMapStorage& shared_tensormap,
            [[maybe_unused]] typename EpilogueOp::Params const& params,
            [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] int32_t next_batch)
        {
        }

        template <bool IsLoad>
        CUTLASS_DEVICE void tensormaps_cp_fence_release([[maybe_unused]] TensorMapStorage& shared_tensormap,
            [[maybe_unused]] cute::TmaDescriptor const* tensormap, [[maybe_unused]] uint32_t lane_predicate)
        {
        }
    };

    using CollectiveOp = Sm90TmaWarpSpecializedAdapterWithSmemStorage<
        EpilogueMoeFusedFinalize<StrideC, ElementD, StrideD, ThreadEpilogueOp, ElementBias, StrideBias, ElementScale,
            StrideScale, EpilogueTile, SmemLayoutAtomD, CopyAtomR2S, CopyAtomS2R, CopyAtomR2G>>;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace collective
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/thread/fused_activations.h (deleted, file mode 100644 → 0, view file @ 9829e77e)
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Functor performing linear combination with a maximum operation used by epilogues.
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/scale_type.h"
#include "cutlass/functional.h"
#include "cutlass/half.h"
#include "cutlass/numeric_conversion.h"
#include "cutlass/numeric_types.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace epilogue
{
namespace thread
{

/////////////////////////////////////////////////////////////////////////////////////////////////

__forceinline__ __device__ float copysignf_pos(float a, float b)
{
    float r;
    r = __int_as_float(__float_as_int(a) | (__float_as_int(b) & 0x80000000));
    return r;
}

__forceinline__ __device__ float tanh_opt(float x)
{
#if (__CUDACC_VER_MAJOR__ < 11) || (__CUDA_ARCH__ < 750)
    float const exp_val = -1.f * fabs(2 * x);
    return copysignf_pos((1.0f - __expf(exp_val)) / (__expf(exp_val) + 1.0f), x);
#else
    return fast_tanh(x);
#endif
}

/////////////////////////////////////////////////////////////////////////////////////////////////

template <>
struct GELU_taylor<float>
{
    static bool const kIsHeavy = true;

    CUTLASS_DEVICE
    float operator()(float const& z) const
    {
        float k0 = float(0.7978845608028654);
        float k1 = float(0.044715);

        return float(
            cutlass::constants::half<float>() * z
            * (cutlass::constants::one<float>() + tanh_opt(k0 * z * (cutlass::constants::one<float>() + k1 * z * z))));
    }

    using Params = LinearCombinationGenericParams<float>;

    CUTLASS_DEVICE
    float operator()(float const& scalar, Params const& params_) const
    {
        return this->operator()(scalar);
    }
};

} // namespace thread
} // namespace epilogue
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_per_row_per_col_scale.h (deleted, file mode 100644 → 0, view file @ 9829e77e)
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue visitor for threadblock scoped INT8 GEMMs that uses one scaling factor per row, and one per column.
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h
*/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/arch/memory.h"
#include "cutlass/arch/memory_sm75.h"
#include "cutlass/cutlass.h"
#include "cutlass/fast_math.h"
#include "cutlass/numeric_conversion.h"
#include "tensorrt_llm/common/quantization.h"
namespace tk = tensorrt_llm::common;

namespace cutlass
{
namespace epilogue
{
namespace threadblock
{

template <typename ThreadblockShape_, int ThreadCount, typename ScaleTileIterator_, typename OutputTileIterator_,
    typename ElementAccumulator_, typename ElementCompute_, typename ElementwiseFunctor_, bool UseMasking_ = false>
class EpilogueVisitorPerRowPerCol
{
public:
    using ThreadblockShape = ThreadblockShape_;
    static int const kThreadCount = ThreadCount;

    using ScaleTileIterator = ScaleTileIterator_;
    using OutputTileIterator = OutputTileIterator_;
    using ElementwiseFunctor = ElementwiseFunctor_;

    static int const kIterations = OutputTileIterator::kIterations;
    static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess;

    using ElementOutput = typename OutputTileIterator::Element;
    using LayoutOutput = cutlass::layout::RowMajor;
    using ElementAccumulator = ElementAccumulator_;

    using AlphaScaleElementType = typename ScaleTileIterator::Element;

    using ElementCompute = ElementCompute_;
    using AccumulatorFragment = Array<ElementAccumulator, kElementsPerAccess>;
    using ComputeFragment = Array<ElementCompute_, kElementsPerAccess>;
    using OutputVector = Array<ElementOutput, kElementsPerAccess>;

    static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::kAccessWidth;
    static bool const kHasMultiStepsInRow = (OutputTileIterator::ThreadMap::Iterations::kColumn > 1);

    /// Argument structure
    struct Arguments
    {
        typename ElementwiseFunctor::Params elementwise;
        int64_t batch_stride_alpha;
        int64_t batch_stride_C;
        int64_t batch_stride_D;

        //
        // Methods
        //
        Arguments()
            : batch_stride_alpha(0)
            , batch_stride_C(0)
            , batch_stride_D(0)
        {
        }

        Arguments(typename ElementwiseFunctor::Params elementwise_)
            : elementwise(elementwise_)
            , batch_stride_alpha(0)
            , batch_stride_C(0)
            , batch_stride_D(0)
        {
        }

        Arguments(typename ElementwiseFunctor::Params elementwise_, int64_t batch_stride_alpha_,
            int64_t batch_stride_C_, int64_t batch_stride_D_)
            : elementwise(elementwise_)
            , batch_stride_alpha(batch_stride_alpha_)
            , batch_stride_C(batch_stride_C_)
            , batch_stride_D(batch_stride_D_)
        {
        }
    };

    struct Params
    {
        typename ElementwiseFunctor::Params elementwise;
        int64_t batch_stride_alpha;
        int64_t batch_stride_C;
        int64_t batch_stride_D;

        //
        // Methods
        //
        CUTLASS_HOST_DEVICE
        Params() {}

        CUTLASS_HOST_DEVICE
        Params(Arguments const& args)
            : elementwise(args.elementwise)
            , batch_stride_alpha(args.batch_stride_alpha)
            , batch_stride_C(args.batch_stride_C)
            , batch_stride_D(args.batch_stride_D)
        {
        }
    };

    /// Shared storage
    struct SharedStorage
    {
    };

private:
    Params const& params_;
    SharedStorage& shared_storage_;
    MatrixCoord extent_;
    MatrixCoord extent_real_;
    ElementwiseFunctor elementwise_;

    bool const per_token_quant_;
    bool const per_channel_quant_;

    AlphaScaleElementType* ptr_alpha_row_;
    AlphaScaleElementType* ptr_alpha_col_;
    ScaleTileIterator iterator_alpha_col_;
    OutputTileIterator iterator_C_;
    OutputTileIterator iterator_D_;

    AlphaScaleElementType element_alpha_row_ = 1.0f;
    AlphaScaleElementType element_alpha_col_ = 1.0f;
    typename ScaleTileIterator::Fragment fragment_alpha_col_;
    typename OutputTileIterator::Fragment fragment_C_;
    typename OutputTileIterator::Fragment fragment_D_;

    ElementAccumulator beta_;

    int column_offset_;

    MatrixCoord thread_offset_;

public:
    CUTLASS_DEVICE
    EpilogueVisitorPerRowPerCol(Params const& params, SharedStorage& shared_storage,
        cutlass::MatrixCoord const& problem_size, int thread_idx, int warp_idx, int lane_idx,
        typename ScaleTileIterator::Params params_alpha_col, typename OutputTileIterator::Params params_C,
        typename OutputTileIterator::Params params_D, tk::QuantMode quant_option, AlphaScaleElementType* ptr_alpha_row,
        AlphaScaleElementType* ptr_alpha_col, typename OutputTileIterator::Element* ptr_C,
        typename OutputTileIterator::Element* ptr_D,
        cutlass::MatrixCoord const& threadblock_offset = cutlass::MatrixCoord(0, 0), int column_offset = 0,
        cutlass::MatrixCoord const& problem_size_real = cutlass::MatrixCoord(0, 0))
        : params_(params)
        , shared_storage_(shared_storage)
        , extent_(problem_size)
        , elementwise_(params.elementwise)
        , per_token_quant_(quant_option.hasPerTokenScaling())
        , per_channel_quant_(quant_option.hasPerChannelScaling())
        , ptr_alpha_row_(ptr_alpha_row)
        , ptr_alpha_col_(ptr_alpha_col)
        , iterator_alpha_col_(params_alpha_col, ptr_alpha_col, problem_size, thread_idx, threadblock_offset)
        , iterator_C_(params_C, ptr_C, problem_size, thread_idx, threadblock_offset)
        , iterator_D_(params_D, ptr_D, problem_size, thread_idx, threadblock_offset)
        , extent_real_(problem_size_real)
    {
        beta_ = (params.elementwise.beta_ptr ? *params.elementwise.beta_ptr : params.elementwise.beta);

        if (beta_ == ElementAccumulator())
        {
            iterator_C_.clear_mask();
        }

        if (!per_channel_quant_ && (ptr_alpha_col_ != nullptr))
        {
            element_alpha_col_ = *ptr_alpha_col_;
        }

        if (!per_token_quant_ && (ptr_alpha_row_ != nullptr))
        {
            element_alpha_row_ = *ptr_alpha_row_;
        }
    }

    /// Helper to indicate split-K behavior
    CUTLASS_DEVICE
    void set_k_partition(int split_k_index, ///< Index of this threadblock within split-K partitioned scheme
        int split_k_slices)                 ///< Total number of split-K slices
    {
    }

    /// Called to set the batch index
    CUTLASS_DEVICE
    void set_batch_index(int batch_idx)
    {
        iterator_alpha_col_.add_pointer_offset(batch_idx * params_.batch_stride_alpha);
        iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C);
        iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D);
    }

    /// Called at the start of the epilogue just before iterating over accumulator slices
    CUTLASS_DEVICE
    void begin_epilogue()
    {
        if (per_channel_quant_)
        {
            iterator_alpha_col_.load(fragment_alpha_col_);
        }
    }

    /// Called at the start of one step before starting accumulator exchange
    CUTLASS_DEVICE
    void begin_step(int step_idx)
    {
        fragment_D_.clear();
        fragment_C_.clear();

        if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling)
        {
            iterator_C_.load(fragment_C_);
            ++iterator_C_;
        }
    }

    /// Called at the start of a row
    CUTLASS_DEVICE
    void begin_row(int row_idx)
    {
        // load alpha_row in begin_step only when per token(row) scaling is used
        if (per_token_quant_)
        {
            int thread_offset_row
                = iterator_D_.thread_start_row() + OutputTileIterator::ThreadMap::iteration_offset(row_idx).row();

            arch::global_load<AlphaScaleElementType, sizeof(AlphaScaleElementType)>(
                element_alpha_row_, ptr_alpha_row_ + thread_offset_row, thread_offset_row < extent_.row());
        }
    }

    /// Called after accumulators have been exchanged for each accumulator vector
    CUTLASS_DEVICE
    void visit(int iter_idx, int row_idx, int column_idx, int frag_idx, AccumulatorFragment const& accum)
    {
        NumericArrayConverter<ElementCompute, ElementAccumulator, kElementsPerAccess> source_converter;

        ComputeFragment result = source_converter(accum);
        if (per_channel_quant_)
        {
            ComputeFragment alpha_col = reinterpret_cast<ComputeFragment*>(&fragment_alpha_col_)[column_idx];
            result = per_token_channel_scale_accumulator_(result, alpha_col, element_alpha_row_);
        }
        else
        {
            result = per_token_scale_accumulator_(result, element_alpha_col_, element_alpha_row_);
        }

        // Convert to the output
        NumericArrayConverter<ElementOutput, ElementCompute, kElementsPerAccess> output_converter;
        OutputVector& output = reinterpret_cast<OutputVector*>(&fragment_D_)[frag_idx];
        output = output_converter(result);
    }

    /// Called at the end of a row
    CUTLASS_DEVICE
    void end_row(int row_idx) {}

    /// Called after all accumulator elements have been visited
    CUTLASS_DEVICE
    void end_step(int step_idx)
    {
        iterator_D_.store(fragment_D_);
        ++iterator_D_;
    }

    /// Called after all steps have been completed
    CUTLASS_DEVICE
    void end_epilogue() {}

private:
    CUTLASS_DEVICE
    ComputeFragment per_token_channel_scale_accumulator_(
        ComputeFragment const& accum, ComputeFragment const& scale_col, AlphaScaleElementType const& scale_row)
    {
        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ComputeFragment::kElements; ++i)
        {
            result[i] = accum[i] * (scale_col[i] * scale_row);
        }

        return result;
    }

    CUTLASS_DEVICE
    ComputeFragment per_token_scale_accumulator_(
        ComputeFragment const& accum, AlphaScaleElementType const& scale_col, AlphaScaleElementType const& scale_row)
    {
        ComputeFragment result;
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < ComputeFragment::kElements; ++i)
        {
            result[i] = accum[i] * (scale_col * scale_row);
        }

        return result;
    }
};

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
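For reference, the visitor above applies one scale per row (per-token) and one per column (per-channel) to each INT32 accumulator before converting to the output type. A minimal host-side sketch of that dequantization; the function name and use of std containers are illustrative, not part of the deleted file:

#include <cstdint>
#include <vector>

// out[i][j] = accum[i][j] * alpha_col[j] * alpha_row[i]
std::vector<float> dequant_reference(const std::vector<int32_t>& accum, const std::vector<float>& alpha_row,
    const std::vector<float>& alpha_col, int rows, int cols)
{
    std::vector<float> out(static_cast<size_t>(rows) * cols);
    for (int i = 0; i < rows; ++i)
    {
        for (int j = 0; j < cols; ++j)
        {
            out[static_cast<size_t>(i) * cols + j]
                = static_cast<float>(accum[static_cast<size_t>(i) * cols + j]) * (alpha_col[j] * alpha_row[i]);
        }
    }
    return out;
}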
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue/threadblock/epilogue_tensor_op_int32.h
deleted
100644 → 0
/***************************************************************************************************
* Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*! \file
\brief Epilogue for threadblock scoped GEMMs using Tensor Ops.
The epilogue rearranges the result of a matrix product through shared memory to match canonical
tensor layouts in global memory. Epilogues support conversion and reduction operations.
original file: 3rdparty/cutlass/include/cutlass/epilogue/threadblock/default_epilogue_tensor_op.h
*/
#pragma once
#include "cutlass/array.h"
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/platform/platform.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_clamp.h"
#include "cutlass/epilogue/thread/linear_combination_gelu.h"
#include "cutlass/epilogue/thread/linear_combination_hardswish.h"
#include "cutlass/epilogue/thread/linear_combination_planar_complex.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_relu0.h"
#include "cutlass/epilogue/thread/linear_combination_sigmoid.h"
#include "cutlass/epilogue/thread/conversion_op.h"
#include "cutlass/epilogue/thread/reduction_op.h"
#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h"
#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_affine.h"
#include "cutlass/epilogue/threadblock/predicated_tile_iterator_strided_dgrad.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator.h"
#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h"
#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h"
#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h"
#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h"
#include "cutlass/epilogue/threadblock/epilogue.h"
#include "cutlass/epilogue/threadblock/interleaved_epilogue.h"
#include "cutlass/layout/permute.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace epilogue
{
namespace threadblock
{

////////////////////////////////////////////////////////////////////////////////

namespace detail
{

/// Partial specialization for bfloat16_t <= int32_t x 8 epilogues avoids shared memory bank conflicts.
template <typename ThreadblockShape, typename WarpShape, typename InstructionShape, typename ThreadMap>
struct DefaultIteratorsTensorOp<cutlass::bfloat16_t, int32_t, 8, ThreadblockShape, WarpShape, InstructionShape,
    ThreadMap>
{
    using WarpTileIterator
        = cutlass::epilogue::warp::TileIteratorTensorOpMixed<WarpShape, InstructionShape, int32_t, 32, 16, 8, 8>;

    using SharedLoadIterator
        = cutlass::epilogue::threadblock::SharedLoadIteratorMixed<ThreadMap, int32_t, 32, 16, 8, 8>;

    static int const kFragmentsPerIteration = 2;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace detail

/////////////////////////////////////////////////////////////////////////////////////////////////

/// Tile iterator used to load output tile from shared memory in epilogue.
///
/// Satisfies: ReadableTileIterator
///
template <typename ThreadMap_ ///< Thread map (concept: OutputTileThreadMap)
    >
class SharedLoadIteratorMixed<ThreadMap_, int32_t, 32, 16, 8, 8>
{
public:
    using ThreadMap = ThreadMap_;
    using Shape = typename ThreadMap::Shape;

    using Element = int32_t;

    using Layout = layout::RowMajor;
    using TensorRef = TensorRef<Element, Layout>;
    using ConstTensorRef = typename TensorRef::ConstTensorRef;

    using Index = typename Layout::Index;
    using LongIndex = typename Layout::LongIndex;
    using TensorCoord = MatrixCoord;

    static int const kElementsPerAccess = ThreadMap::kElementsPerAccess;

    static int const kAlignment = ThreadMap::kElementsPerAccess * sizeof_bits<Element>::value / 8;

    static int const kThreads = ThreadMap::kThreads;

    /// Fragment object
    using Fragment = Array<Element,
        ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * ThreadMap::Iterations::kGroup
            * ThreadMap::Iterations::kCluster * ThreadMap::kElementsPerAccess>;

    /// Memory access size
    using AccessType = AlignedArray<Element, ThreadMap::kElementsPerAccess, kAlignment>;

    /// Vector type used for SMEM loads
    using LoadType = AlignedArray<Element, const_min(128 / sizeof_bits<Element>::value, ThreadMap::kElementsPerAccess),
        const_min(16, kAlignment)>;

    static int const kLoadsPerAccess = AccessType::kElements / LoadType::kElements;

private:
    //
    // Data members
    //

    /// Byte-level pointer
    LoadType const* pointers_[kLoadsPerAccess];

    /// Stride along adjacent rows in units of LoadType
    int stride_;

public:
    //
    // Methods
    //

    /// Constructor
    CUTLASS_DEVICE
    SharedLoadIteratorMixed(TensorRef ref, int thread_idx)
        : stride_((ref.stride(0) / LoadType::kElements))
    {
        TensorCoord thread_offset = ThreadMap::initial_offset(thread_idx);

        // Initialize pointers
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kLoadsPerAccess; ++i)
        {
            pointers_[i] = reinterpret_cast<LoadType const*>(ref.data());

            int col_idx = (thread_offset.column() / kElementsPerAccess) * kLoadsPerAccess;
            int bank_offset = (col_idx * int(sizeof(LoadType)) / 128) % kLoadsPerAccess;

            col_idx += (bank_offset + i) % kLoadsPerAccess;

            pointers_[i] += thread_offset.row() * stride_ + col_idx;
        }
    }

    /// Adds a pointer offset in units of Element
    CUTLASS_HOST_DEVICE
    void add_pointer_offset(LongIndex pointer_offset)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kLoadsPerAccess; ++i)
        {
            pointers_[i] += pointer_offset / LoadType::kElements;
        }
    }

    CUTLASS_DEVICE
    void add_tile_offset(TensorCoord const& offset)
    {
        CUTLASS_PRAGMA_UNROLL
        for (int i = 0; i < kLoadsPerAccess; ++i)
        {
            pointers_[i]
                += offset.row() * Shape::kRow * stride_ + offset.column() * Shape::kColumn / LoadType::kElements;
        }
    }

    /// Loads a fragment from memory
    CUTLASS_DEVICE
    void load_with_pointer_offset(Fragment& frag, Index pointer_offset) const
    {
        CUTLASS_PRAGMA_UNROLL
        for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; ++cluster)
        {
            CUTLASS_PRAGMA_UNROLL
            for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group)
            {
                CUTLASS_PRAGMA_UNROLL
                for (int row = 0; row < ThreadMap::Iterations::kRow; ++row)
                {
                    int row_ptr_offset = row * ThreadMap::Delta::kRow * stride_
                        + group * ThreadMap::Delta::kGroup * stride_ + cluster * ThreadMap::Delta::kCluster * stride_
                        + pointer_offset / LoadType::kElements;

                    int frag_row_idx
                        = (row + ThreadMap::Iterations::kRow * (group + ThreadMap::Iterations::kGroup * cluster));

                    LoadType* frag_ptr = reinterpret_cast<LoadType*>(&frag);

                    CUTLASS_PRAGMA_UNROLL
                    for (int column = 0; column < ThreadMap::Iterations::kColumn; ++column)
                    {
                        int frag_idx = frag_row_idx * ThreadMap::Iterations::kColumn + column;

                        CUTLASS_PRAGMA_UNROLL
                        for (int v = 0; v < kLoadsPerAccess; ++v)
                        {
                            int vector_idx = (column * ThreadMap::Delta::kColumn / kElementsPerAccess * kLoadsPerAccess);

                            LoadType const* memory_pointer = pointers_[v] + row_ptr_offset;

                            frag_ptr[frag_idx * kLoadsPerAccess + v] = memory_pointer[vector_idx];
                        }
                    }
                }
            }
        }
    }

    /// Loads a fragment
    CUTLASS_DEVICE
    void load(Fragment& frag) const
    {
        load_with_pointer_offset(frag, 0);
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace threadblock
} // namespace epilogue
} // namespace cutlass
////////////////////////////////////////////////////////////////////////////////
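The constructor above staggers the kLoadsPerAccess shared-memory pointers by a per-pointer bank offset so that concurrent loads do not land in the same 128-byte bank region. A small host-side sketch of that index arithmetic; the function name and scalar parameters are illustrative, not part of the deleted file:

// Returns the swizzled column index used for pointer i in the constructor above.
int swizzled_col_idx(int thread_column, int elements_per_access, int loads_per_access, int load_type_bytes, int i)
{
    int col_idx = (thread_column / elements_per_access) * loads_per_access;
    int bank_offset = (col_idx * load_type_bytes / 128) % loads_per_access;
    return col_idx + (bank_offset + i) % loads_per_access;
}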
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/epilogue_helpers.h
deleted
100644 → 0
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* @file epilogue_helpers.h
*
* This file includes types for the epilogues. The empty structs exist so we can signal to template
* code the type of epilogue we want to run, and let the underlying code specify the details such as
* element types, accumulator type and elements per vector access.
*
*/
#pragma once
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/epilogue/thread/linear_combination_generic.h"
#include "cutlass/epilogue/thread/linear_combination_relu.h"
#include "cutlass/epilogue/thread/linear_combination_silu.h"
#include "cutlass_extensions/epilogue/thread/fused_activations.h"
#include <cutlass/epilogue/fusion/operations.hpp>
namespace tensorrt_llm
{
namespace cutlass_extensions
{

struct EpilogueOpBiasSilu
{
};

struct EpilogueOpBiasReLU
{
};

struct EpilogueOpBiasFtGelu
{
};

struct EpilogueOpBias
{
};

struct EpilogueOpDefaultSilu
{
};

struct EpilogueOpDefaultReLU
{
};

struct EpilogueOpDefaultFtGelu
{
};

struct EpilogueOpDefault
{
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator, typename Op>
struct Epilogue
{
    static_assert(sizeof(ElementType) == 0, "Unrecognized Epilogue Tag");
};

constexpr auto BiasScaleMode = cutlass::epilogue::thread::ScaleType::NoBetaScaling;

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasSilu>
{
    using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
        ElementAccumulator, ElementAccumulator, BiasScaleMode>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasReLU>
{
    using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
        ElementAccumulator, ElementAccumulator, BiasScaleMode>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBiasFtGelu>
{
    using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
        ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, BiasScaleMode,
        cutlass::FloatRoundStyle::round_to_nearest, true>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpBias>
{
    using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
        ElementAccumulator, BiasScaleMode>;
};

constexpr auto DefaultScaleMode = cutlass::epilogue::thread::ScaleType::Default;

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultSilu>
{
    using Op = cutlass::epilogue::thread::LinearCombinationSilu<ElementType, ElementsPerVectorAccess,
        ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultReLU>
{
    using Op = cutlass::epilogue::thread::LinearCombinationRelu<ElementType, ElementsPerVectorAccess,
        ElementAccumulator, ElementAccumulator, DefaultScaleMode>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefaultFtGelu>
{
    using Op = cutlass::epilogue::thread::LinearCombinationGeneric<cutlass::epilogue::thread::GELU_taylor, ElementType,
        ElementsPerVectorAccess, ElementAccumulator, ElementAccumulator, DefaultScaleMode,
        cutlass::FloatRoundStyle::round_to_nearest, true>;
};

template <typename ElementType, int ElementsPerVectorAccess, typename ElementAccumulator>
struct Epilogue<ElementType, ElementsPerVectorAccess, ElementAccumulator, EpilogueOpDefault>
{
    using Op = cutlass::epilogue::thread::LinearCombination<ElementType, ElementsPerVectorAccess, ElementAccumulator,
        ElementAccumulator, DefaultScaleMode>;
};

} // namespace cutlass_extensions
} // namespace tensorrt_llm
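The tag structs above are resolved at compile time to concrete CUTLASS thread-level operators through the Epilogue<> partial specializations. A hedged usage sketch; the element type, vector width, and accumulator type are illustrative and assume this header plus CUTLASS are on the include path:

// Selects LinearCombinationSilu with NoBetaScaling for half_t output and float accumulation.
using SiluEpilogueOp = tensorrt_llm::cutlass_extensions::Epilogue<cutlass::half_t, /*ElementsPerVectorAccess=*/8,
    /*ElementAccumulator=*/float, tensorrt_llm::cutlass_extensions::EpilogueOpBiasSilu>::Op;
// Passing an unrecognized tag instead would trip the primary template's static_assert("Unrecognized Epilogue Tag").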
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl
deleted
100644 → 0
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/arch/mma.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/collective/builders/sm90_common.inl"
// SM90 Collective Builders should be used only starting CUDA 12.0
#if (__CUDACC_VER_MAJOR__ >= 12)
#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
#endif
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace detail
{
// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
template <int CapacityBytes, class ElementA, class ElementB, class TileShapeMNK, bool SwapAB, int carveout_bytes>
constexpr int compute_stage_count_or_override_gated(StageCountAutoCarveout<carveout_bytes> stage_count)
{
// 32 bytes to account for barriers etc.
constexpr int stage_barrier_bytes = 32;
constexpr int a_bits = static_cast<int>(sizeof_bits<ElementA>::value);
constexpr int b_bits = static_cast<int>(sizeof_bits<ElementB>::value);
constexpr int stage_bytes = [&]() -> int
{
if constexpr (SwapAB)
{
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8
+ (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8 + stage_barrier_bytes;
}
else
{
return (a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) / 8
+ (b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{}) * 2) / 8 + stage_barrier_bytes;
}
}();
return (CapacityBytes - carveout_bytes) / stage_bytes;
}
} // namespace detail
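// Worked example of the stage-count formula above (assumed, illustrative parameters only): with
// 16-bit A/B operands, a 128x128x64 TileShapeMNK, SwapAB = false, zero carveout, and CUTLASS's
// assumed 232448-byte SM90 shared-memory capacity, the gated mainloop stages the B operand twice
// (value and gate), so one pipeline stage costs
//     (16 * 128 * 64) / 8 + (16 * 128 * 64 * 2) / 8 + 32 = 16384 + 32768 + 32 = 49184 bytes.
constexpr int kIllustrativeStageBytes = (16 * 128 * 64) / 8 + (16 * 128 * 64 * 2) / 8 + 32; // = 49184
constexpr int kIllustrativeStages = 232448 / kIllustrativeStageBytes;                       // = 4 pipeline stages
static_assert(kIllustrativeStages == 4, "illustrative arithmetic only");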
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_SS
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
class KernelScheduleType, template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB,
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType,
Activation, SwapAB,
cute::enable_if_t<(cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecialized>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpong>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>
|| cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>) && !detail::
is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>()>>
{
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Should meet TMA alignment requirement\n");
static constexpr bool IsArrayOfPointersGemm
= (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperative>);
static constexpr bool IsFP8Input = detail::is_input_fp8<ElementA, ElementB>();
static_assert(!IsFP8Input || (IsFP8Input && !IsArrayOfPointersGemm),
"Kernel[Array/Group]TmaWarpSpecializedCooperative is only compatible with FP8 FastAccum version right now\n");
// For fp32 types, map to tf32 MMA value type
using MmaElementA = cute::conditional_t<cute::is_same_v<ElementA, float>, tfloat32_t, ElementA>;
using MmaElementB = cute::conditional_t<cute::is_same_v<ElementB, float>, tfloat32_t, ElementB>;
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<MmaElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<MmaElementB, GmemLayoutB>();
using AtomLayoutMNK = cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperative>
|| IsArrayOfPointersGemm,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector<MmaElementA, MmaElementB,
ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<GmmaMajorA, MmaElementA,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<GmmaMajorB, MmaElementB,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages
= detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, MmaElementA, MmaElementB,
TileShape_MNK, SwapAB>(StageCountType{});
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
/* For FP8 use a separate mainloop compared to other datatypes */
cute::conditional_t<IsFP8Input,
MainloopSm90TmaGmmaWarpSpecializedFP8<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
// GMMA_TMA_WS_FP8_FAST_ACCUM_SS
template <class ElementA, class GmemLayoutA, int AlignmentA, class ElementB, class GmemLayoutB, int AlignmentB,
class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK, class StageCountType,
class KernelScheduleType, template <class /* ElementCompute */> class Activation, bool SwapAB>
struct CollectiveBuilderGated<arch::Sm90, arch::OpClassTensorOp, ElementA, GmemLayoutA, AlignmentA, ElementB,
GmemLayoutB, AlignmentB, ElementAccumulator, TileShape_MNK, ClusterShape_MNK, StageCountType, KernelScheduleType,
Activation, SwapAB,
cute::enable_if_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedFP8FastAccum>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedPingpongFP8FastAccum>
|| cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>
|| cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>>>
{
static_assert(is_static<TileShape_MNK>::value);
static_assert(is_static<ClusterShape_MNK>::value);
static_assert(detail::is_aligned<ElementA, AlignmentA, ElementB, AlignmentB, detail::tma_alignment_bytes>(),
"Not meet TMA alignment requirement yet\n");
static_assert(
detail::is_input_fp8<ElementA, ElementB>(), "Only FP8 datatypes are compatible with these kernel schedules\n");
// Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder
static_assert(!detail::is_use_rmem_A<ElementA, GmemLayoutA, ElementB, GmemLayoutB>(),
"Not supported for fp8 non-TN warp specialized kernels yet\n");
#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED
static_assert(cutlass::detail::dependent_false<ElementA>, "Unsupported Toolkit for SM90 Collective Builder\n");
#endif
static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A<ElementA, GmemLayoutA>();
static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B<ElementB, GmemLayoutB>();
static constexpr bool IsArrayOfPointersGemm
= (cute::is_same_v<KernelScheduleType, KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum>);
using AtomLayoutMNK
= cute::conditional_t<cute::is_same_v<KernelScheduleType, KernelTmaWarpSpecializedCooperativeFP8FastAccum>
|| IsArrayOfPointersGemm,
Layout<Shape<_2, _1, _1>>, Layout<Shape<_1, _1, _1>>>;
using TiledMma = decltype(cute::make_tiled_mma(
cute::GMMA::ss_op_selector<ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(),
AtomLayoutMNK{}));
using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{})));
using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{})));
using SmemLayoutAtomA = decltype(detail::ss_smem_selector<GmmaMajorA, ElementA,
decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
using SmemLayoutAtomB = decltype(detail::ss_smem_selector<GmmaMajorB, ElementB,
decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
static constexpr int PipelineStages
= detail::compute_stage_count_or_override_gated<detail::sm90_smem_capacity_bytes, ElementA, ElementB,
TileShape_MNK, SwapAB>(StageCountType{});
using DispatchPolicy = cute::conditional_t<IsArrayOfPointersGemm,
MainloopSm90ArrayTmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>,
MainloopSm90TmaGmmaWarpSpecialized<PipelineStages, ClusterShape_MNK, KernelScheduleType>>;
using SmemCopyAtomA = void;
using SmemCopyAtomB = void;
using CollectiveOp = CollectiveMmaGated<DispatchPolicy, TileShape_MNK, ElementA, TagToStrideA_t<GmemLayoutA>,
ElementB, TagToStrideB_t<GmemLayoutB>, TiledMma, GmemTiledCopyA, SmemLayoutAtomA, SmemCopyAtomA, cute::identity,
GmemTiledCopyB, SmemLayoutAtomB, SmemCopyAtomB, cute::identity, Activation, SwapAB>;
};
/////////////////////////////////////////////////////////////////////////////////////////////////
/////////////////////////////////////////////////////////////////////////////////////////////////
} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_builder_gated.hpp
deleted
100644 → 0
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass_extensions/gemm/collective/collective_mma_gated.hpp"
namespace cutlass::gemm::collective
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <class ArchTag, class OpClass, class ElementA, class GmemLayoutA, int AlignmentA, class ElementB,
    class GmemLayoutB, int AlignmentB, class ElementAccumulator, class TileShape_MNK, class ClusterShape_MNK,
    class StageCountType, class KernelScheduleType, template <class /* ElementCompute */> class Activation,
    bool SwapAB = false, class Enable = void>
struct CollectiveBuilderGated
{
    static_assert(sizeof(ElementA) == 0, "Could not build a collective for given parameters.");
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/builders/sm90_gmma_builder_gated.inl"
/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/collective_mma_gated.hpp
deleted
100644 → 0
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cutlass/detail/dependent_false.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <class DispatchPolicy, class TileShape, class ElementA, class StrideA, class ElementB, class StrideB,
    class TiledMma, class GmemTiledCopyA, class SmemLayoutAtomA, class SmemCopyAtomA, class TransformA,
    class GmemTiledCopyB, class SmemLayoutAtomB, class SmemCopyAtomB, class TransformB,
    template <class /* ElementCompute */> class Activation, bool SwapAB = false>
struct CollectiveMmaGated
{
    static_assert(cutlass::detail::dependent_false<ElementA>, "Could not find a mainloop specialization.");
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective
/////////////////////////////////////////////////////////////////////////////////////////////////
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp"
#include "cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp"
/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized.hpp
deleted
100644 → 0
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/tensor_predicate.hpp"
#include "cutlass/pipeline/pipeline.hpp"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace
cutlass
::
gemm
::
collective
{
using
namespace
cute
;
/////////////////////////////////////////////////////////////////////////////////////////////////
// WarpSpecialized Mainloop
template
<
int
Stages
,
class
ClusterShape
,
class
KernelSchedule
,
class
TileShape_
,
class
ElementA_
,
class
StrideA_
,
class
ElementB_
,
class
StrideB_
,
class
TiledMma_
,
class
GmemTiledCopyA_
,
class
SmemLayoutAtomA_
,
class
SmemCopyAtomA_
,
class
TransformA_
,
class
GmemTiledCopyB_
,
class
SmemLayoutAtomB_
,
class
SmemCopyAtomB_
,
class
TransformB_
,
template
<
class
/* ElementCompute */
>
class
Activation_
,
bool
SwapAB_
>
struct
CollectiveMmaGated
<
MainloopSm90TmaGmmaWarpSpecialized
<
Stages
,
ClusterShape
,
KernelSchedule
>
,
TileShape_
,
ElementA_
,
StrideA_
,
ElementB_
,
StrideB_
,
TiledMma_
,
GmemTiledCopyA_
,
SmemLayoutAtomA_
,
SmemCopyAtomA_
,
TransformA_
,
GmemTiledCopyB_
,
SmemLayoutAtomB_
,
SmemCopyAtomB_
,
TransformB_
,
Activation_
,
SwapAB_
>
{
static
constexpr
bool
isGated
=
true
;
static
constexpr
bool
SwapAB
=
SwapAB_
;
//
// Type Aliases
//
using
DispatchPolicy
=
MainloopSm90TmaGmmaWarpSpecialized
<
Stages
,
ClusterShape
,
KernelSchedule
>
;
using
TileShape
=
TileShape_
;
using
ElementA
=
ElementA_
;
using
StrideA
=
StrideA_
;
using
ElementB
=
ElementB_
;
using
StrideB
=
StrideB_
;
using
TiledMma
=
TiledMma_
;
using
ElementAccumulator
=
typename
TiledMma
::
ValTypeC
;
using
GmemTiledCopyA
=
GmemTiledCopyA_
;
using
GmemTiledCopyB
=
GmemTiledCopyB_
;
using
SmemLayoutAtomA
=
SmemLayoutAtomA_
;
using
SmemLayoutAtomB
=
SmemLayoutAtomB_
;
using
SmemCopyAtomA
=
SmemCopyAtomA_
;
using
SmemCopyAtomB
=
SmemCopyAtomB_
;
using
TransformA
=
TransformA_
;
using
TransformB
=
TransformB_
;
using
ArchTag
=
typename
DispatchPolicy
::
ArchTag
;
using
Activation
=
Activation_
<
ElementAccumulator
>
;
using
ElementAux
=
cute
::
conditional_t
<
SwapAB
,
ElementA_
,
ElementB_
>
;
using
ValTypeAux
=
cute
::
conditional_t
<
SwapAB
,
typename
TiledMma
::
ValTypeA
,
typename
TiledMma
::
ValTypeB
>
;
using
MainloopPipeline
=
cutlass
::
PipelineTmaAsync
<
DispatchPolicy
::
Stages
>
;
using
PipelineState
=
cutlass
::
PipelineState
<
DispatchPolicy
::
Stages
>
;
using
PipelineParams
=
typename
MainloopPipeline
::
Params
;
static_assert
(
cute
::
rank
(
SmemLayoutAtomA
{})
==
2
,
"SmemLayoutAtom must be rank 2 (M/N, K)"
);
static_assert
(
(
size
<
0
>
(
TileShape
{})
%
size
<
0
>
(
SmemLayoutAtomA
{}))
==
0
,
"SmemLayoutAtom must evenly divide tile shape."
);
static_assert
(
(
size
<
2
>
(
TileShape
{})
%
size
<
1
>
(
SmemLayoutAtomA
{}))
==
0
,
"SmemLayoutAtom must evenly divide tile shape."
);
static_assert
(
cute
::
rank
(
SmemLayoutAtomB
{})
==
2
,
"SmemLayoutAtom must be rank 2 (M/N, K)"
);
static_assert
(
(
size
<
1
>
(
TileShape
{})
%
size
<
0
>
(
SmemLayoutAtomB
{}))
==
0
,
"SmemLayoutAtom must evenly divide tile shape."
);
static_assert
(
(
size
<
2
>
(
TileShape
{})
%
size
<
1
>
(
SmemLayoutAtomB
{}))
==
0
,
"SmemLayoutAtom must evenly divide tile shape."
);
// Tile along modes in a way that maximizes the TMA box size.
using
SmemLayoutA
=
decltype
(
tile_to_shape
(
SmemLayoutAtomA
{},
make_shape
(
shape
<
0
>
(
TileShape
{}),
shape
<
2
>
(
TileShape
{}),
Int
<
DispatchPolicy
::
Stages
>
{}),
conditional_t
<::
cutlass
::
gemm
::
detail
::
is_major
<
0
,
StrideA
>
(),
Step
<
_2
,
_1
,
_3
>
,
Step
<
_1
,
_2
,
_3
>>
{}));
using
SmemLayoutB
=
decltype
(
tile_to_shape
(
SmemLayoutAtomB
{},
make_shape
(
shape
<
1
>
(
TileShape
{}),
shape
<
2
>
(
TileShape
{}),
Int
<
DispatchPolicy
::
Stages
>
{}),
conditional_t
<::
cutlass
::
gemm
::
detail
::
is_major
<
0
,
StrideB
>
(),
Step
<
_2
,
_1
,
_3
>
,
Step
<
_1
,
_2
,
_3
>>
{}));
using
SmemLayoutAux
=
cute
::
conditional_t
<
SwapAB
,
SmemLayoutA
,
SmemLayoutB
>
;
static_assert
(
DispatchPolicy
::
Stages
>=
2
,
"Specialization requires Stages set to value 2 or more."
);
static_assert
(
cute
::
is_base_of
<
cute
::
GMMA
::
DescriptorIterator
,
typename
TiledMma
::
FrgTypeA
>::
value
&&
cute
::
is_base_of
<
cute
::
GMMA
::
DescriptorIterator
,
typename
TiledMma
::
FrgTypeB
>::
value
,
"MMA atom must source both A and B operand from smem_desc for this mainloop."
);
static_assert
(
cute
::
is_same_v
<
GmemTiledCopyA
,
SM90_TMA_LOAD
>
||
cute
::
is_same_v
<
GmemTiledCopyA
,
SM90_TMA_LOAD_MULTICAST
>
,
"GmemTiledCopy - invalid SM90 TMA copy atom specified."
);
static_assert
(
cute
::
is_same_v
<
GmemTiledCopyB
,
SM90_TMA_LOAD
>
||
cute
::
is_same_v
<
GmemTiledCopyB
,
SM90_TMA_LOAD_MULTICAST
>
,
"GmemTiledCopy - invalid SM90 TMA copy atom specified."
);
// TMA converts f32 input to tf32 when copying from GMEM to SMEM
// For all other types, cast to size equivalent uint type to avoid any rounding by TMA.
static
constexpr
bool
ConvertF32toTF32A
=
cute
::
is_same_v
<
float
,
ElementA
>
;
static
constexpr
bool
ConvertF32toTF32B
=
cute
::
is_same_v
<
float
,
ElementB
>
;
using
InternalElementA
=
cute
::
conditional_t
<
ConvertF32toTF32A
,
tfloat32_t
,
uint_bit_t
<
sizeof_bits_v
<
ElementA
>>>
;
using
InternalElementB
=
cute
::
conditional_t
<
ConvertF32toTF32B
,
tfloat32_t
,
uint_bit_t
<
sizeof_bits_v
<
ElementB
>>>
;
using
InternalElementAux
=
cute
::
conditional_t
<
SwapAB
,
InternalElementA
,
InternalElementB
>
;
struct
SharedStorage
{
struct
TensorStorage
:
cute
::
aligned_struct
<
128
>
{
cute
::
array_aligned
<
typename
TiledMma
::
ValTypeA
,
cute
::
cosize_v
<
SmemLayoutA
>>
smem_A
;
cute
::
array_aligned
<
typename
TiledMma
::
ValTypeB
,
cute
::
cosize_v
<
SmemLayoutB
>>
smem_B
;
cute
::
array_aligned
<
ValTypeAux
,
cute
::
cosize_v
<
SmemLayoutAux
>>
smem_Aux
;
}
tensors
;
using
PipelineStorage
=
typename
MainloopPipeline
::
SharedStorage
;
PipelineStorage
pipeline
;
};
using
TensorStorage
=
typename
SharedStorage
::
TensorStorage
;
using
PipelineStorage
=
typename
SharedStorage
::
PipelineStorage
;
// Host side kernel arguments
struct
Arguments
{
ElementA
const
*
ptr_A
;
StrideA
dA
;
ElementB
const
*
ptr_B
;
StrideB
dB
;
float
scale_d0
=
1.0
f
;
float
scale_d1
=
1.0
f
;
uint32_t
mma_promotion_interval
=
4
;
};
// Device side kernel params
struct
Params
{
// Assumption: StrideA is congruent with Problem_MK
using
TMA_A
=
decltype
(
make_tma_copy
(
GmemTiledCopyA
{},
make_tensor
(
static_cast
<
InternalElementA
const
*>
(
nullptr
),
repeat_like
(
StrideA
{},
int32_t
(
0
)),
StrideA
{}),
SmemLayoutA
{}(
_
,
_
,
cute
::
Int
<
0
>
{}),
make_shape
(
shape
<
0
>
(
TileShape
{}),
shape
<
2
>
(
TileShape
{})),
size
<
1
>
(
ClusterShape
{})));
// mcast along N mode for this M load, if any
// Assumption: StrideB is congruent with Problem_NK
using
TMA_B
=
decltype
(
make_tma_copy
(
GmemTiledCopyB
{},
make_tensor
(
static_cast
<
InternalElementB
const
*>
(
nullptr
),
repeat_like
(
StrideB
{},
int32_t
(
0
)),
StrideB
{}),
SmemLayoutB
{}(
_
,
_
,
cute
::
Int
<
0
>
{}),
make_shape
(
shape
<
1
>
(
TileShape
{}),
shape
<
2
>
(
TileShape
{})),
size
<
0
>
(
ClusterShape
{})));
// mcast along M mode for this N load, if any
using
TMA_Aux
=
cute
::
conditional_t
<
SwapAB
,
TMA_A
,
TMA_B
>
;
TMA_A
tma_load_a
;
TMA_B
tma_load_b
;
TMA_Aux
tma_load_aux
;
float
scale_d0
=
1.0
f
;
float
scale_d1
=
1.0
f
;
};
//
// Methods
//
template
<
class
ProblemShape
>
static
constexpr
Params
to_underlying_arguments
(
ProblemShape
const
&
problem_shape
,
Arguments
const
&
args
,
void
*
workspace
)
{
(
void
)
workspace
;
// Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
auto
problem_shape_MNKL
=
append
<
4
>
(
problem_shape
,
1
);
auto
[
M
,
N
,
K
,
L
]
=
problem_shape_MNKL
;
auto
ptr_A
=
reinterpret_cast
<
InternalElementA
const
*>
(
args
.
ptr_A
);
auto
ptr_B
=
reinterpret_cast
<
InternalElementB
const
*>
(
args
.
ptr_B
);
Tensor
tensor_a
=
make_tensor
(
ptr_A
,
make_layout
(
make_shape
(
M
,
K
,
L
),
args
.
dA
));
Tensor
tensor_b
=
make_tensor
(
ptr_B
,
make_layout
(
make_shape
(
N
,
K
,
L
),
args
.
dB
));
typename
Params
::
TMA_A
tma_load_a
=
make_tma_copy
(
GmemTiledCopyA
{},
tensor_a
,
SmemLayoutA
{}(
_
,
_
,
cute
::
Int
<
0
>
{}),
make_shape
(
shape
<
0
>
(
TileShape
{}),
shape
<
2
>
(
TileShape
{})),
size
<
1
>
(
ClusterShape
{}));
// mcast along N mode for this M load, if any
typename
Params
::
TMA_B
tma_load_b
=
make_tma_copy
(
GmemTiledCopyB
{},
tensor_b
,
SmemLayoutB
{}(
_
,
_
,
cute
::
Int
<
0
>
{}),
make_shape
(
shape
<
1
>
(
TileShape
{}),
shape
<
2
>
(
TileShape
{})),
size
<
0
>
(
ClusterShape
{}));
// mcast along M mode for this N load, if any
if
constexpr
(
SwapAB
)
{
auto
ptr_Aux
=
reinterpret_cast
<
InternalElementA
const
*>
(
args
.
ptr_A
+
size
(
make_shape
(
M
,
K
,
L
)));
Tensor
tensor_aux
=
make_tensor
(
ptr_Aux
,
make_layout
(
make_shape
(
M
,
K
,
L
),
args
.
dA
));
typename
Params
::
TMA_Aux
tma_load_aux
=
make_tma_copy
(
GmemTiledCopyA
{},
tensor_aux
,
SmemLayoutA
{}(
_
,
_
,
cute
::
Int
<
0
>
{}),
make_shape
(
shape
<
0
>
(
TileShape
{}),
shape
<
2
>
(
TileShape
{})),
size
<
1
>
(
ClusterShape
{}));
// mcast along N mode for this M load, if any
return
{
tma_load_a
,
tma_load_b
,
tma_load_aux
,
args
.
scale_d0
,
args
.
scale_d1
};
}
else
{
auto
ptr_Aux
=
reinterpret_cast
<
InternalElementB
const
*>
(
args
.
ptr_B
+
size
(
make_shape
(
N
,
K
,
L
)));
Tensor
tensor_aux
=
make_tensor
(
ptr_Aux
,
make_layout
(
make_shape
(
N
,
K
,
L
),
args
.
dB
));
typename
Params
::
TMA_Aux
tma_load_aux
=
make_tma_copy
(
GmemTiledCopyB
{},
tensor_aux
,
SmemLayoutB
{}(
_
,
_
,
cute
::
Int
<
0
>
{}),
make_shape
(
shape
<
1
>
(
TileShape
{}),
shape
<
2
>
(
TileShape
{})),
size
<
0
>
(
ClusterShape
{}));
// mcast along M mode for this N load, if any
return
{
tma_load_a
,
tma_load_b
,
tma_load_aux
,
args
.
scale_d0
,
args
.
scale_d1
};
}
}
template
<
class
ProblemShape
>
static
bool
can_implement
(
ProblemShape
const
&
problem_shape
,
[[
maybe_unused
]]
Arguments
const
&
args
)
{
constexpr
int
tma_alignment_bits
=
128
;
auto
problem_shape_MNKL
=
append
<
4
>
(
problem_shape
,
1
);
auto
[
M
,
N
,
K
,
L
]
=
problem_shape_MNKL
;
bool
implementable
=
true
;
constexpr
int
min_tma_aligned_elements_A
=
tma_alignment_bits
/
cutlass
::
sizeof_bits
<
ElementA
>::
value
;
implementable
=
implementable
&&
cutlass
::
detail
::
check_alignment
<
min_tma_aligned_elements_A
>
(
cute
::
make_shape
(
M
,
K
,
L
),
StrideA
{});
constexpr
int
min_tma_aligned_elements_B
=
tma_alignment_bits
/
cutlass
::
sizeof_bits
<
ElementB
>::
value
;
implementable
=
implementable
&&
cutlass
::
detail
::
check_alignment
<
min_tma_aligned_elements_B
>
(
cute
::
make_shape
(
N
,
K
,
L
),
StrideB
{});
if
(
!
implementable
)
{
CUTLASS_TRACE_HOST
(
" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.
\n
"
);
}
return
implementable
;
}
    static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
    static constexpr int K_PIPE_MMAS = 1;
    static constexpr uint32_t TmaTransactionBytes
        = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) / 8
        + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value)) / 8
        + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast<uint32_t>(sizeof_bits<ElementAux>::value)) / 8;
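The transaction-byte expression above sums the per-stage shared-memory footprints of A, B, and the gated auxiliary operand. A small arithmetic sketch for a hypothetical 128x128x64 tile of 16-bit data (illustrative numbers only):

#include <cstdint>
#include <cstdio>
// Per-stage TMA transaction bytes for A, B, and the aux operand, mirroring the expression above.
int main() {
    uint32_t const blk_m = 128, blk_n = 128, blk_k = 64;      // hypothetical tile shape
    uint32_t const bits_a = 16, bits_b = 16, bits_aux = 16;   // hypothetical element widths
    uint32_t const bytes = (blk_m * blk_k * bits_a) / 8
        + (blk_n * blk_k * bits_b) / 8
        + (blk_n * blk_k * bits_aux) / 8;
    std::printf("TmaTransactionBytes per stage: %u\n", bytes); // 49152
    return 0;
}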
    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& mainloop_params)
    {
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor());
    }

    /// Set up the data needed by this collective for load and mma.
    /// Returns a tuple of tensors. The collective and the kernel layer have the contract
    /// Returned tuple must contain at least two elements, with the first two elements being:
    /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
    /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
    /// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l)
    /// The rest of the tensors can be specified as needed by this collective.
    template <class ProblemShape_MNKL>
    CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const
    {
        using X = Underscore;
        // Separate out problem shape for convenience
        auto [M, N, K, L] = problem_shape_MNKL;

        // TMA requires special handling of strides to deal with coord codomain mapping
        // Represent the full tensors -- get these from TMA
        Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
        Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)

        // Make tiled views, defer the slice
        Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
        Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)

        if constexpr (SwapAB)
        {
            Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
            Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
            return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
        }
        else
        {
            Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
            Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
            return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
        }
    }

    /// Perform a collective-scoped matrix multiply-accumulate
    /// Producer Perspective
    template <class TensorA, class TensorB, class TensorAux, class KTileIterator, class BlockCoord>
    CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write,
        cute::tuple<TensorA, TensorB, TensorAux> const& load_inputs, BlockCoord const& blk_coord,
        KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster,
        TensorStorage& shared_tensors)
    {
        int lane_predicate = cute::elect_one_sync();

        if (lane_predicate)
        {
            Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
            Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
            Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});

            //
            // Prepare the TMA loads for A and B
            //
            constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape());
            uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

            Tensor gA_mkl = get<0>(load_inputs);
            Tensor gB_nkl = get<1>(load_inputs);
            Tensor gAux_xkl = get<2>(load_inputs);

            auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
            auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
            auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
                                        : mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x);

            // Partition the inputs based on the current block coordinates.
            auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
            Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
            Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
            Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord);

            // Applies the mapping from block_tma_a
            Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
            Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)

            Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
            Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)

            Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
            Tensor tAuxsAux = block_tma_aux.partition_D(sAux);

            uint16_t mcast_mask_a = 0;
            uint16_t mcast_mask_b = 0;
            uint16_t mcast_mask_aux = 0;

            // Issue TmaLoads
            // Maps the tile -> block, value
            if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)
            {
                auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
                for (int n = 0; n < size<1>(block_layout); ++n)
                {
                    mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{}));
                }
            }

            if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)
            {
                auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
                for (int m = 0; m < size<0>(block_layout); ++m)
                {
                    mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{}));
                }
            }

            if constexpr (SwapAB)
            {
                mcast_mask_aux = mcast_mask_a;
            }
            else
            {
                mcast_mask_aux = mcast_mask_b;
            }

            // Mainloop
            CUTLASS_PRAGMA_NO_UNROLL
            for (; k_tile_count > 0; --k_tile_count)
            {
                // LOCK smem_pipe_write for _writing_
                pipeline.producer_acquire(smem_pipe_write);

                //
                // Copy gmem to smem for *k_tile_iter
                //
                using BarrierType = typename MainloopPipeline::ProducerBarrierType;
                BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);

                int write_stage = smem_pipe_write.index();
                copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter),
                    tAsA(_, _, _, write_stage));
                copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter),
                    tBsB(_, _, _, write_stage));
                copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter),
                    tAuxsAux(_, _, _, write_stage));
                ++k_tile_iter;

                // Advance smem_pipe_write
                ++smem_pipe_write;
            }
        }
    }

    /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
    CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write)
    {
        int lane_predicate = cute::elect_one_sync();

        // Issue the epilogue waits
        if (lane_predicate)
        {
            /* This helps avoid early exit of blocks in Cluster
             * Waits for all stages to either be released (all
             * Consumer UNLOCKs), or if the stage was never used
             * then would just be acquired since the phase was
             * still inverted from make_producer_start_state
             */
            pipeline.producer_tail(smem_pipe_write);
        }
    }

    /// Perform a collective-scoped matrix multiply-accumulate
    /// Consumer Perspective
    template <class FrgTensorC>
    CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0,
        FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors, Params const& mainloop_params)
    {
        static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
        static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
        static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
        static_assert(cute::rank(SmemLayoutAux{}) == 3, "Smem layout must be rank 3.");
        static_assert(cute::is_void_v<SmemCopyAtomA>,
            "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
        static_assert(cute::is_void_v<SmemCopyAtomB>,
            "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");

        Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
        Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
        Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});

        //
        // Define C accumulators and A/B partitioning
        //
        TiledMma tiled_mma;
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);

        Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
        Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)

        // Allocate "fragments/descriptors"
        Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
        Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)

        auto tCsAux = [&]() -> auto
        {
            if constexpr (SwapAB)
            {
                return thread_mma.partition_A(sAux);
            }
            else
            {
                return thread_mma.partition_B(sAux);
            }
        }();
        auto tCrAux = [&]() -> auto
        {
            if constexpr (SwapAB)
            {
                return thread_mma.make_fragment_A(tCsAux);
            }
            else
            {
                return thread_mma.make_fragment_B(tCsAux);
            }
        }();

        CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
        CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
        CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB));   // K
        CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB));   // PIPE
        if constexpr (SwapAB)
        {
            CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
            CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1));   // N
            CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux));   // K
            CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux));   // PIPE
        }
        else
        {
            CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1));   // M
            CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
            CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux));   // K
            CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux));   // PIPE
        }
        CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA));   // PIPE
        CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB));   // PIPE
        CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sAux)); // PIPE

        //
        // PIPELINED MAIN LOOP
        //
        static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight");

        // We release buffers to producer warps(dma load) with some mmas in flight
        PipelineState smem_pipe_release = smem_pipe_read;

        // Prologue GMMAs
        int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);

        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;

        warpgroup_fence_operand(accum0);
        warpgroup_fence_operand(accum1);
        CUTLASS_PRAGMA_UNROLL
        for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
        {
            // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);

            int read_stage = smem_pipe_read.index();
            warpgroup_arrive();
            // Unroll the K mode manually to set scale D to 1
            CUTLASS_PRAGMA_UNROLL
            for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
            {
                // (V,M,K) x (V,N,K) => (V,M,N)
                cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0);
                if constexpr (SwapAB)
                {
                    cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1);
                }
                else
                {
                    cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1);
                }
                tiled_mma.accumulate_ = GMMA::ScaleOut::One;
            }

            warpgroup_commit_batch();

            ++smem_pipe_read;
        }

        warpgroup_fence_operand(accum0);
        warpgroup_fence_operand(accum1);
        // Mainloop GMMAs
        k_tile_count -= prologue_mma_count;

        CUTLASS_PRAGMA_NO_UNROLL
        for (; k_tile_count > 0; --k_tile_count)
        {
            // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);

            //
            // Compute on k_tile
            //
            int read_stage = smem_pipe_read.index();
            warpgroup_fence_operand(accum0);
            warpgroup_fence_operand(accum1);
            warpgroup_arrive();
            // Unroll the K mode manually to set scale D to 1
            CUTLASS_PRAGMA_UNROLL
            for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
            {
                // (V,M,K) x (V,N,K) => (V,M,N)
                cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum0);
                if constexpr (SwapAB)
                {
                    cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accum1);
                }
                else
                {
                    cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accum1);
                }
                tiled_mma.accumulate_ = GMMA::ScaleOut::One;
            }
            warpgroup_commit_batch();

            /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
            warpgroup_wait<K_PIPE_MMAS>();
            warpgroup_fence_operand(accum0);
            warpgroup_fence_operand(accum1);

            // UNLOCK smem_pipe_release, done _computing_ on it
            pipeline.consumer_release(smem_pipe_release);

            // Advance smem_pipe_read and smem_pipe_release
            ++smem_pipe_read;
            ++smem_pipe_release;
        }

        warpgroup_fence_operand(accum0);
        warpgroup_fence_operand(accum1);
    }

    /// Perform a Consumer Epilogue to release all buffers
    CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count)
    {
        // Prologue GMMAs
        int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
        k_tile_count -= prologue_mma_count;

        smem_pipe_release.advance(k_tile_count);

        // Wait on all GMMAs to complete
        warpgroup_wait<0>();

        for (int count = 0; count < prologue_mma_count; ++count)
        {
            pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
            ++smem_pipe_release;
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
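Note how the auxiliary (gate) operand in to_underlying_arguments above is addressed: it is assumed to be packed immediately after the main A (or B, depending on SwapAB) matrix, at an offset of M*K*L (resp. N*K*L) elements. A small host-side sketch of that addressing convention, using hypothetical sizes and a plain float buffer rather than the kernel's element types:

#include <cstddef>
#include <vector>
int main() {
    std::size_t const N = 4096, K = 512, L = 1;               // hypothetical problem extents
    // One allocation holds the main and gate halves back to back,
    // matching ptr_Aux = args.ptr_B + size(make_shape(N, K, L)).
    std::vector<float> fused(2 * N * K * L, 0.0f);
    float const* ptr_B   = fused.data();
    float const* ptr_Aux = ptr_B + N * K * L;                  // second half = gate operand
    (void) ptr_Aux;
    return 0;
}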
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_gated_tma_gmma_ss_warpspecialized_fp8.hpp deleted 100644 → 0
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
#pragma once
#include "cute/arch/cluster_sm90.hpp"
#include "cute/arch/copy_sm90.hpp"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cute/algorithm/functional.hpp"
#include "cute/algorithm/gemm.hpp"
#include "cute/atom/mma_atom.hpp"
#include "cute/numeric/arithmetic_tuple.hpp"
#include "cute/tensor_predicate.hpp"
#include "cutlass/epilogue/thread/activation.h"
#include "cutlass/gemm/collective/fp8_accumulation.hpp"
#include "cutlass/trace.h"
/////////////////////////////////////////////////////////////////////////////////////////////////
namespace cutlass::gemm::collective
{
using namespace cute;

/////////////////////////////////////////////////////////////////////////////////////////////////

// WarpSpecialized Mainloop
template <int Stages, class ClusterShape, class KernelSchedule, class TileShape_, class ElementA_, class StrideA_,
    class ElementB_, class StrideB_, class TiledMma_, class GmemTiledCopyA_, class SmemLayoutAtomA_,
    class SmemCopyAtomA_, class TransformA_, class GmemTiledCopyB_, class SmemLayoutAtomB_, class SmemCopyAtomB_,
    class TransformB_, template <class /* ElementCompute */> class Activation_, bool SwapAB_>
struct CollectiveMmaGated<MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>, TileShape_,
    ElementA_, StrideA_, ElementB_, StrideB_, TiledMma_, GmemTiledCopyA_, SmemLayoutAtomA_, SmemCopyAtomA_, TransformA_,
    GmemTiledCopyB_, SmemLayoutAtomB_, SmemCopyAtomB_, TransformB_, Activation_, SwapAB_>
{
    static constexpr bool isGated = true;
    static constexpr bool SwapAB = SwapAB_;

    //
    // Type Aliases
    //
    using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedFP8<Stages, ClusterShape, KernelSchedule>;
    using TileShape = TileShape_;
    using ElementA = ElementA_;
    using StrideA = StrideA_;
    using ElementB = ElementB_;
    using StrideB = StrideB_;
    using TiledMma = TiledMma_;
    using ElementAccumulator = typename TiledMma::ValTypeC;
    using GmemTiledCopyA = GmemTiledCopyA_;
    using GmemTiledCopyB = GmemTiledCopyB_;
    using SmemLayoutAtomA = SmemLayoutAtomA_;
    using SmemLayoutAtomB = SmemLayoutAtomB_;
    using SmemCopyAtomA = SmemCopyAtomA_;
    using SmemCopyAtomB = SmemCopyAtomB_;
    using TransformA = TransformA_;
    using TransformB = TransformB_;
    using ArchTag = typename DispatchPolicy::ArchTag;
    using Activation = Activation_<ElementAccumulator>;
    using ElementAux = cute::conditional_t<SwapAB, ElementA_, ElementB_>;
    using ValTypeAux = cute::conditional_t<SwapAB, typename TiledMma::ValTypeA, typename TiledMma::ValTypeB>;

    using MainloopPipeline = cutlass::PipelineTmaAsync<DispatchPolicy::Stages>;
    using PipelineState = cutlass::PipelineState<DispatchPolicy::Stages>;
    using PipelineParams = typename MainloopPipeline::Params;

    static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
    static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
    static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

    static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)");
    static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");
    static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape.");

    // Tile along modes in a way that maximizes the TMA box size.
    using SmemLayoutA = decltype(tile_to_shape(SmemLayoutAtomA{},
        make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
        conditional_t<::cutlass::gemm::detail::is_major<0, StrideA>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));
    using SmemLayoutB = decltype(tile_to_shape(SmemLayoutAtomB{},
        make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int<DispatchPolicy::Stages>{}),
        conditional_t<::cutlass::gemm::detail::is_major<0, StrideB>(), Step<_2, _1, _3>, Step<_1, _2, _3>>{}));

    using SmemLayoutAux = cute::conditional_t<SwapAB, SmemLayoutA, SmemLayoutB>;

    static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more.");
    static_assert(cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeA>::value
            && cute::is_base_of<cute::GMMA::DescriptorIterator, typename TiledMma::FrgTypeB>::value,
        "MMA atom must source both A and B operand from smem_desc for this mainloop.");
    static_assert(cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>,
        "GmemTiledCopy - invalid SM90 TMA copy atom specified.");
    static_assert(cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD> || cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>,
        "GmemTiledCopy - invalid SM90 TMA copy atom specified.");

    struct SharedStorage
    {
        struct TensorStorage : cute::aligned_struct<128>
        {
            cute::array_aligned<typename TiledMma::ValTypeA, cute::cosize_v<SmemLayoutA>> smem_A;
            cute::array_aligned<typename TiledMma::ValTypeB, cute::cosize_v<SmemLayoutB>> smem_B;
            cute::array_aligned<ValTypeAux, cute::cosize_v<SmemLayoutAux>> smem_Aux;
        } tensors;

        using PipelineStorage = typename MainloopPipeline::SharedStorage;
        PipelineStorage pipeline;
    };

    using TensorStorage = typename SharedStorage::TensorStorage;
    using PipelineStorage = typename SharedStorage::PipelineStorage;

    // Host side kernel arguments
    struct Arguments
    {
        ElementA const* ptr_A;
        StrideA dA;
        ElementB const* ptr_B;
        StrideB dB;
        float scale_d0 = 1.0f;
        float scale_d1 = 1.0f;
        uint32_t mma_promotion_interval = 4;
    };

    // Device side kernel params
    struct Params
    {
        // Assumption: StrideA is congruent with Problem_MK
        using TMA_A = decltype(make_tma_copy(GmemTiledCopyA{},
            make_tensor(static_cast<ElementA const*>(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}),
            SmemLayoutA{}(_, _, 0), make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
            size<1>(ClusterShape{}))); // mcast along N mode for this M load, if any
        // Assumption: StrideB is congruent with Problem_NK
        using TMA_B = decltype(make_tma_copy(GmemTiledCopyB{},
            make_tensor(static_cast<ElementB const*>(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}),
            SmemLayoutB{}(_, _, 0), make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
            size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
        using TMA_Aux = cute::conditional_t<SwapAB, TMA_A, TMA_B>;

        TMA_A tma_load_a;
        TMA_B tma_load_b;
        TMA_Aux tma_load_aux;
        float scale_d0 = 1.0f;
        float scale_d1 = 1.0f;
        uint32_t mma_promotion_interval = 4;
    };

    //
    // Methods
    //
    template <class ProblemShape>
    static constexpr Params to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace)
    {
        (void) workspace;

        // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK)
        auto problem_shape_MNKL = append<4>(problem_shape, 1);
        auto [M, N, K, L] = problem_shape_MNKL;

        auto ptr_A = reinterpret_cast<ElementA const*>(args.ptr_A);
        auto ptr_B = reinterpret_cast<ElementB const*>(args.ptr_B);

        Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M, K, L), args.dA));
        Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N, K, L), args.dB));
        typename Params::TMA_A tma_load_a = make_tma_copy(GmemTiledCopyA{}, tensor_a, SmemLayoutA{}(_, _, cute::Int<0>{}),
            make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
            size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
        typename Params::TMA_B tma_load_b = make_tma_copy(GmemTiledCopyB{}, tensor_b, SmemLayoutB{}(_, _, cute::Int<0>{}),
            make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
            size<0>(ClusterShape{})); // mcast along M mode for this N load, if any

        if constexpr (SwapAB)
        {
            auto ptr_Aux = reinterpret_cast<ElementA const*>(args.ptr_A + size(make_shape(M, K, L)));
            Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(M, K, L), args.dA));
            typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyA{}, tensor_aux, SmemLayoutA{}(_, _, cute::Int<0>{}),
                make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})),
                size<1>(ClusterShape{})); // mcast along N mode for this M load, if any
            return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval};
        }
        else
        {
            auto ptr_Aux = reinterpret_cast<ElementB const*>(args.ptr_B + size(make_shape(N, K, L)));
            Tensor tensor_aux = make_tensor(ptr_Aux, make_layout(make_shape(N, K, L), args.dB));
            typename Params::TMA_Aux tma_load_aux = make_tma_copy(GmemTiledCopyB{}, tensor_aux, SmemLayoutB{}(_, _, cute::Int<0>{}),
                make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})),
                size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
            return {tma_load_a, tma_load_b, tma_load_aux, args.scale_d0, args.scale_d1, args.mma_promotion_interval};
        }
    }

    template <class ProblemShape>
    static bool can_implement(ProblemShape const& problem_shape, [[maybe_unused]] Arguments const& args)
    {
        constexpr int tma_alignment_bits = 128;
        auto problem_shape_MNKL = append<4>(problem_shape, 1);
        auto [M, N, K, L] = problem_shape_MNKL;

        bool implementable = true;
        constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits<ElementA>::value;
        implementable = implementable
            && cutlass::detail::check_alignment<min_tma_aligned_elements_A>(cute::make_shape(M, K, L), StrideA{});
        constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits<ElementB>::value;
        implementable = implementable
            && cutlass::detail::check_alignment<min_tma_aligned_elements_B>(cute::make_shape(N, K, L), StrideB{});
        /* MMA promotion interval should be a multiple of 4, since each mainloop iteration would issue 4 MMA
         * instructions. */
        implementable = implementable && (args.mma_promotion_interval % 4 == 0);

        if (!implementable)
        {
            CUTLASS_TRACE_HOST("  CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n");
        }
        return implementable;
    }

    static constexpr int K_PIPE_MAX = DispatchPolicy::Stages;
    static constexpr int K_PIPE_MMAS = 1;
    static constexpr uint32_t TmaTransactionBytes
        = (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast<uint32_t>(sizeof_bits<ElementA>::value)) / 8
        + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast<uint32_t>(sizeof_bits<ElementB>::value)) / 8
        + (size<0>(SmemLayoutAux{}) * size<1>(SmemLayoutAux{}) * static_cast<uint32_t>(sizeof_bits<ElementAux>::value)) / 8;

    /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
    CUTLASS_DEVICE
    static void prefetch_tma_descriptors(Params const& mainloop_params)
    {
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor());
        cute::prefetch_tma_descriptor(mainloop_params.tma_load_aux.get_tma_descriptor());
    }

    /// Set up the data needed by this collective for load and mma.
    /// Returns a tuple of tensors. The collective and the kernel layer have the contract
    /// Returned tuple must contain at least two elements, with the first two elements being:
    /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l)
    /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l)
    /// gAux_xkl - The tma tensor, A/B after a local tile so it has shape (BLK_N,BLK_K,m/n,k,l)
    template <class ProblemShape_MNKL>
    CUTLASS_DEVICE auto load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const
    {
        using X = Underscore;
        // Separate out problem shape for convenience
        auto [M, N, K, L] = problem_shape_MNKL;

        // TMA requires special handling of strides to deal with coord codomain mapping
        // Represent the full tensors -- get these from TMA
        Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
        Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)

        // Make tiled views, defer the slice
        Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
        Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)

        if constexpr (SwapAB)
        {
            Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(M, K, L)); // (m,k,l)
            Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<_1, X, _1>{}); // (BLK_M,BLK_K,m,k,l)
            return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
        }
        else
        {
            Tensor mAux_xkl = mainloop_params.tma_load_aux.get_tma_tensor(make_shape(N, K, L)); // (n,k,l)
            Tensor gAux_xkl = local_tile(mAux_xkl, TileShape{}, make_coord(_, _, _), Step<X, _1, _1>{}); // (BLK_N,BLK_K,n,k,l)
            return cute::make_tuple(gA_mkl, gB_nkl, gAux_xkl);
        }
    }

    /// Perform a collective-scoped matrix multiply-accumulate
    /// Producer Perspective
    template <class TensorA, class TensorB, class TensorAux, class KTileIterator, class BlockCoord>
    CUTLASS_DEVICE void load(Params const& mainloop_params, MainloopPipeline pipeline, PipelineState smem_pipe_write,
        cute::tuple<TensorA, TensorB, TensorAux> const& load_inputs, BlockCoord const& blk_coord,
        KTileIterator k_tile_iter, int k_tile_count, int thread_idx, uint32_t block_rank_in_cluster,
        TensorStorage& shared_tensors)
    {
        int lane_predicate = cute::elect_one_sync();

        if (lane_predicate)
        {
            Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
            Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
            Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});

            //
            // Prepare the TMA loads for A and B
            //
            constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
            uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};

            Tensor gA_mkl = get<0>(load_inputs);
            Tensor gB_nkl = get<1>(load_inputs);
            Tensor gAux_xkl = get<2>(load_inputs);

            auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y);
            auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x);
            auto block_tma_aux = SwapAB ? mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.y)
                                        : mainloop_params.tma_load_aux.get_slice(cluster_local_block_id.x);

            // Partition the inputs based on the current block coordinates.
            auto [m_coord, n_coord, k_coord, l_coord] = blk_coord;
            Tensor gA = gA_mkl(_, _, m_coord, _, l_coord); // (BLK_M,BLK_K,k)
            Tensor gB = gB_nkl(_, _, n_coord, _, l_coord); // (BLK_N,BLK_K,k)
            Tensor gAux = SwapAB ? gAux_xkl(_, _, m_coord, _, l_coord) : gAux_xkl(_, _, n_coord, _, l_coord);

            // Applies the mapping from block_tma_a
            Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k)
            Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE)

            Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k)
            Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE)

            Tensor tAuxgAux = block_tma_aux.partition_S(gAux);
            Tensor tAuxsAux = block_tma_aux.partition_D(sAux);

            uint16_t mcast_mask_a = 0;
            uint16_t mcast_mask_b = 0;
            uint16_t mcast_mask_aux = 0;

            // Issue TmaLoads
            // Maps the tile -> block, value
            if constexpr (cute::is_same_v<GmemTiledCopyA, SM90_TMA_LOAD_MULTICAST>)
            {
                auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
                for (int n = 0; n < size<1>(block_layout); ++n)
                {
                    mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x, n, Int<0>{}));
                }
            }

            if constexpr (cute::is_same_v<GmemTiledCopyB, SM90_TMA_LOAD_MULTICAST>)
            {
                auto block_layout = Layout<typename DispatchPolicy::ClusterShape>{}; // (m,n) -> block_id
                for (int m = 0; m < size<0>(block_layout); ++m)
                {
                    mcast_mask_b |= (uint16_t(1) << block_layout(m, cluster_local_block_id.y, Int<0>{}));
                }
            }

            if constexpr (SwapAB)
            {
                mcast_mask_aux = mcast_mask_a;
            }
            else
            {
                mcast_mask_aux = mcast_mask_b;
            }

            // Mainloop
            CUTLASS_PRAGMA_NO_UNROLL
            for (; k_tile_count > 0; --k_tile_count)
            {
                // LOCK smem_pipe_write for _writing_
                pipeline.producer_acquire(smem_pipe_write);

                //
                // Copy gmem to smem for *k_tile_iter
                //
                using BarrierType = typename MainloopPipeline::ProducerBarrierType;
                BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write);

                int write_stage = smem_pipe_write.index();
                copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_, _, _, *k_tile_iter),
                    tAsA(_, _, _, write_stage));
                copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_, _, _, *k_tile_iter),
                    tBsB(_, _, _, write_stage));
                copy(mainloop_params.tma_load_aux.with(*tma_barrier, mcast_mask_aux), tAuxgAux(_, _, _, *k_tile_iter),
                    tAuxsAux(_, _, _, write_stage));
                ++k_tile_iter;

                // Advance smem_pipe_write
                ++smem_pipe_write;
            }
        }
    }

    /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
    CUTLASS_DEVICE void load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write)
    {
        int lane_predicate = cute::elect_one_sync();

        // Issue the epilogue waits
        if (lane_predicate)
        {
            /* This helps avoid early exit of blocks in Cluster
             * Waits for all stages to either be released (all
             * Consumer UNLOCKs), or if the stage was never used
             * then would just be acquired since the phase was
             * still inverted from make_producer_start_state
             */
            pipeline.producer_tail(smem_pipe_write);
        }
    }

    /// Perform a collective-scoped matrix multiply-accumulate
    /// Consumer Perspective
    template <class FrgTensorC>
    CUTLASS_DEVICE void mma(MainloopPipeline pipeline, PipelineState smem_pipe_read, FrgTensorC& accum0,
        FrgTensorC& accum1, int k_tile_count, int thread_idx, TensorStorage& shared_tensors, Params const& mainloop_params)
    {
        static_assert(is_rmem<FrgTensorC>::value, "C tensor must be rmem resident.");
        static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3.");
        static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3.");
        static_assert(cute::is_void_v<SmemCopyAtomA>,
            "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");
        static_assert(cute::is_void_v<SmemCopyAtomB>,
            "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions.");

        Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE)
        Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE)
        Tensor sAux = make_tensor(make_smem_ptr(shared_tensors.smem_Aux.data()), SmemLayoutAux{});

        //
        // Define C accumulators and A/B partitioning
        //
        TiledMma tiled_mma;
        auto thread_mma = tiled_mma.get_thread_slice(thread_idx);

        Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE)
        Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE)

        // Allocate "fragments/descriptors"
        Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE)
        Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE)

        auto tCsAux = [&]() -> auto
        {
            if constexpr (SwapAB)
            {
                return thread_mma.partition_A(sAux);
            }
            else
            {
                return thread_mma.partition_B(sAux);
            }
        }();
        auto tCrAux = [&]() -> auto
        {
            if constexpr (SwapAB)
            {
                return thread_mma.make_fragment_A(tCsAux);
            }
            else
            {
                return thread_mma.make_fragment_B(tCsAux);
            }
        }();

        CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum0)); // M
        CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum0)); // N
        CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB));   // K
        CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB));   // PIPE
        if constexpr (SwapAB)
        {
            CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<1>(accum1)); // M
            CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum1));   // N
            CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCsAux));   // K
            CUTE_STATIC_ASSERT_V(size<3>(tCsB) == size<3>(tCsAux));   // PIPE
        }
        else
        {
            CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum1));   // M
            CUTE_STATIC_ASSERT_V(size<1>(tCsAux) == size<2>(accum1)); // N
            CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsAux));   // K
            CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsAux));   // PIPE
        }
        CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sA));   // PIPE
        CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sB));   // PIPE
        CUTE_STATIC_ASSERT_V(Int<DispatchPolicy::Stages>{} == size<2>(sAux)); // PIPE

        //
        // PIPELINED MAIN LOOP
        //
        static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), "ERROR : Incorrect number of MMAs in flight");

        // We release buffers to producer warps(dma load) with some mmas in flight
        PipelineState smem_pipe_release = smem_pipe_read;

        // Prologue GMMAs
        int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);

        tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;

        GmmaFP8Accumulation accumulation0(accum0, mainloop_params.mma_promotion_interval, size<2>(tCrA));
        GmmaFP8Accumulation accumulation1(accum1, mainloop_params.mma_promotion_interval, size<2>(tCrA));
        warpgroup_fence_operand(accumulation0());
        warpgroup_fence_operand(accumulation1());
        CUTLASS_PRAGMA_UNROLL
        for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue)
        {
            // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);

            if (accumulation0.prepare_if_needed())
            {
                tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
            }

            int read_stage = smem_pipe_read.index();
            warpgroup_arrive();
            // Unroll the K mode manually to set scale D to 1
            CUTLASS_PRAGMA_UNROLL
            for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
            {
                // (V,M,K) x (V,N,K) => (V,M,N)
                cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0());
                if constexpr (SwapAB)
                {
                    cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1());
                }
                else
                {
                    cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1());
                }
                tiled_mma.accumulate_ = GMMA::ScaleOut::One;
            }

            warpgroup_commit_batch();

            accumulation0.promote_if_needed();
            accumulation1.promote_if_needed();

            ++smem_pipe_read;
        }

        warpgroup_fence_operand(accumulation0());
        warpgroup_fence_operand(accumulation1());
        // Mainloop GMMAs
        k_tile_count -= prologue_mma_count;

        CUTLASS_PRAGMA_NO_UNROLL
        for (; k_tile_count > 0; --k_tile_count)
        {
            // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value)
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);

            //
            // Compute on k_tile
            //
            int read_stage = smem_pipe_read.index();

            if (accumulation0.prepare_if_needed())
            {
                tiled_mma.accumulate_ = GMMA::ScaleOut::Zero;
            }

            warpgroup_fence_operand(accumulation0());
            warpgroup_fence_operand(accumulation1());
            warpgroup_arrive();
            // Unroll the K mode manually to set scale D to 1
            CUTLASS_PRAGMA_UNROLL
            for (int k_block = 0; k_block < size<2>(tCrA); ++k_block)
            {
                // (V,M,K) x (V,N,K) => (V,M,N)
                cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation0());
                if constexpr (SwapAB)
                {
                    cute::gemm(tiled_mma, tCrAux(_, _, k_block, read_stage), tCrB(_, _, k_block, read_stage), accumulation1());
                }
                else
                {
                    cute::gemm(tiled_mma, tCrA(_, _, k_block, read_stage), tCrAux(_, _, k_block, read_stage), accumulation1());
                }
                tiled_mma.accumulate_ = GMMA::ScaleOut::One;
            }
            warpgroup_commit_batch();

            /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed
            warpgroup_wait<K_PIPE_MMAS>();
            warpgroup_fence_operand(accumulation0());
            warpgroup_fence_operand(accumulation1());

            accumulation0.promote_if_needed();
            accumulation1.promote_if_needed();

            pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it

            // Advance smem_pipe_read and smem_pipe_release
            ++smem_pipe_read;
            ++smem_pipe_release;
        }

        accumulation0.promote_residue_if_needed();
        accumulation1.promote_residue_if_needed();

        warpgroup_fence_operand(accumulation0());
        warpgroup_fence_operand(accumulation1());
    }

    /// Perform a Consumer Epilogue to release all buffers
    CUTLASS_DEVICE void mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count)
    {
        // Prologue GMMAs
        int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count);
        k_tile_count -= prologue_mma_count;

        smem_pipe_release.advance(k_tile_count);

        // Wait on all GMMAs to complete
        warpgroup_wait<0>();

        for (int count = 0; count < prologue_mma_count; ++count)
        {
            pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it
            ++smem_pipe_release;
        }
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace cutlass::gemm::collective

/////////////////////////////////////////////////////////////////////////////////////////////////
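The FP8 variant above differs from the non-FP8 mainloop mainly in wrapping each accumulator in GmmaFP8Accumulation and folding partial results every mma_promotion_interval tiles (which can_implement requires to be a multiple of 4). As a schematic of the interval-based promotion idea only, not the CUTLASS GmmaFP8Accumulation implementation:

#include <cstdio>
// Schematic: partial sums are folded into the main accumulator every `interval`
// iterations (promote_if_needed) and once more for the residue (promote_residue_if_needed).
int main() {
    int const k_tiles = 10, interval = 4;   // hypothetical counts
    float accum = 0.0f, partial = 0.0f;
    for (int k = 0; k < k_tiles; ++k) {
        partial += 1.0f;                    // stand-in for one tile's MMA contribution
        if ((k + 1) % interval == 0) {      // periodic promotion
            accum += partial;
            partial = 0.0f;
        }
    }
    accum += partial;                       // residue promotion
    std::printf("accum = %.1f\n", accum);   // 10.0
    return 0;
}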
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/gemm_universal_base_compat.h deleted 100644 → 0
/***************************************************************************************************
* Copyright (c) 2017 - 2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief The universal GEMM accommodates serial reductions, parallel reductions, batched strided, and
batched array variants.
*/
#pragma once
// #include <limits>
#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/trace.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace device
{

/////////////////////////////////////////////////////////////////////////////////////////////////

/*
    This is the device layer from CUTLASS 2.10 (SHA - cc85b64cf676c45f98a17e3a47c0aafcf817f088)
    It is replicated here since we needed to duplicate kernel level APIs for mixed dtype GEMMs
    and SmoothQuant. The newer device layer is not compatible with these older kernel level APIs.

    Note: While CUTLASS 3.x supports stream-k, none of the kernels in the extensions folder support
    that feature at the moment.
*/

template <typename GemmKernel_>
class GemmUniversalBaseCompat
{
public:
    using GemmKernel = GemmKernel_;
    using ThreadblockShape = typename GemmKernel::Mma::Shape;

    using ElementA = typename GemmKernel::ElementA;
    using LayoutA = typename GemmKernel::LayoutA;
    using TensorRefA = TensorRef<ElementA const, LayoutA>;
    static ComplexTransform const kTransformA = GemmKernel::kTransformA;

    using ElementB = typename GemmKernel::ElementB;
    using LayoutB = typename GemmKernel::LayoutB;
    using TensorRefB = TensorRef<ElementB const, LayoutB>;
    static ComplexTransform const kTransformB = GemmKernel::kTransformB;

    using ElementC = typename GemmKernel::ElementC;
    using LayoutC = typename GemmKernel::LayoutC;
    using TensorRefC = TensorRef<ElementC const, LayoutC>;
    using TensorRefD = TensorRef<ElementC, LayoutC>;

    using ElementAccumulator = typename GemmKernel::Mma::Policy::Operator::ElementC;

    using EpilogueOutputOp = typename GemmKernel::EpilogueOutputOp;
    using ThreadblockSwizzle = typename GemmKernel::ThreadblockSwizzle;
    using Operator = typename GemmKernel::Operator;

    /// Argument structure
    using Arguments = typename GemmKernel::Arguments;

protected:
    /// Kernel parameters object
    typename GemmKernel::Params params_;

protected:
    /// Private helper to obtain the grid dimensions with fix-up for split-K
    static void get_grid_shape_(gemm::GemmCoord& grid_tiled_shape, int& gemm_k_size, Arguments const& args)
    {
        // Determine grid shape
        ThreadblockSwizzle threadblock_swizzle;

        grid_tiled_shape = threadblock_swizzle.get_tiled_shape(
            args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.batch_count);

        gemm_k_size = args.problem_size.k();

        if (args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel)
        {
            int const kAlignK
                = const_max(const_max(128 / sizeof_bits<ElementA>::value, 128 / sizeof_bits<ElementB>::value), 1);

            gemm_k_size = round_up(ceil_div(args.problem_size.k(), args.batch_count), kAlignK);

            if (gemm_k_size)
            {
                grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size);
            }
        }
    }

public:
    /// Constructs the GEMM.
    GemmUniversalBaseCompat() {}

    /// Determines whether the GEMM can execute the given problem.
    static Status can_implement(Arguments const& args)
    {
        // Determine grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        ThreadblockSwizzle threadblock_swizzle;
        dim3 grid = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

        uint32_t const kGridYZMax = ((1 << (sizeof(uint16_t) * 8)) - 1);

        if (!(grid.y <= kGridYZMax && grid.z <= kGridYZMax))
        {
            return Status::kErrorInvalidProblem;
        }

        return GemmKernel::can_implement(args);
    }

    /// Gets the workspace size
    static size_t get_workspace_size(Arguments const& args)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_workspace_size()");

        size_t workspace_bytes = 0;

        // Determine grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        if (args.mode == GemmUniversalMode::kGemmSplitKParallel)
        {
            // Split-K parallel always requires a temporary workspace
            workspace_bytes = sizeof(ElementC) * size_t(args.batch_stride_D) * size_t(grid_tiled_shape.k());
        }
        else if (args.mode == GemmUniversalMode::kGemm && grid_tiled_shape.k() > 1)
        {
            // Serial split-K only requires a temporary workspace if the number of partitions along the
            // GEMM K dimension is greater than one.
            workspace_bytes = sizeof(int) * size_t(grid_tiled_shape.m()) * size_t(grid_tiled_shape.n());
        }

        CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);

        workspace_bytes += GemmKernel::get_extra_workspace_size(args, grid_tiled_shape);

        return workspace_bytes;
    }

    /// Computes the grid shape
    static dim3 get_grid_shape(Arguments const& args)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::get_grid_shape()");

        ThreadblockSwizzle threadblock_swizzle;

        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);
        dim3 result = threadblock_swizzle.get_grid_shape(grid_tiled_shape);

        CUTLASS_TRACE_HOST("  grid_tiled_shape: " << grid_tiled_shape << "\n"
                                                  << "  result = {" << result << "}");

        return result;
    }

    /// Computes the maximum number of active blocks per multiprocessor
    static int maximum_active_blocks(int smem_capacity = -1)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::maximum_active_blocks()");

        int max_active_blocks = -1;
        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        CUTLASS_TRACE_HOST("  smem_size: " << smem_size << " bytes");

        if (smem_size <= (48 << 10))
        {
            cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, smem_size);

            if (result == cudaSuccess)
            {
                CUTLASS_TRACE_HOST("  max_active_blocks: " << max_active_blocks);
                return max_active_blocks;
            }
        }
        else
        {
            // Query assuming zero shared memory then compute occupancy limit based on SMEM
            cudaError_t result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
                &max_active_blocks, Kernel<GemmKernel>, GemmKernel::kThreadCount, 0);

            if (result != cudaSuccess)
            {
                CUTLASS_TRACE_HOST(
                    "  cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));

                return -1;
            }

            if (smem_capacity < 0)
            {
                int device_idx = 0;
                result = cudaGetDevice(&device_idx);

                if (result != cudaSuccess)
                {
                    return -1;
                }

                cudaDeviceProp properties;
                result = cudaGetDeviceProperties(&properties, device_idx);

                if (result != cudaSuccess)
                {
                    return -1;
                }

                smem_capacity = static_cast<int>(properties.sharedMemPerMultiprocessor);
            }

            int occupancy = std::min(max_active_blocks, smem_capacity / smem_size);

            CUTLASS_TRACE_HOST("  occupancy: " << occupancy);

            return occupancy;
        }

        CUTLASS_TRACE_HOST("  returning internal error");

        return -1;
    }

    /// Initializes GEMM state from arguments.
    Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::initialize() - workspace "
            << workspace << ", stream: " << (stream ? "non-null" : "null"));

        size_t workspace_bytes = get_workspace_size(args);

        CUTLASS_TRACE_HOST("  workspace_bytes: " << workspace_bytes);

        if (workspace_bytes)
        {
            if (!workspace)
            {
                CUTLASS_TRACE_HOST("  error: device workspace must not be null");
                return Status::kErrorWorkspaceNull;
            }

            if (args.mode == GemmUniversalMode::kGemm)
            {
                CUTLASS_TRACE_HOST("  clearing device workspace");
                cudaError_t result = cudaMemsetAsync(workspace, 0, workspace_bytes, stream);

                if (result != cudaSuccess)
                {
                    CUTLASS_TRACE_HOST("  cudaMemsetAsync() returned error " << cudaGetErrorString(result));
                    return Status::kErrorInternal;
                }
            }
        }

        // Get CUDA grid shape
        cutlass::gemm::GemmCoord grid_tiled_shape;
        int gemm_k_size = 0;

        get_grid_shape_(grid_tiled_shape, gemm_k_size, args);

        // Initialize the Params structure
        params_ = typename GemmKernel::Params(args, grid_tiled_shape, gemm_k_size, static_cast<int*>(workspace));

        // Specify shared memory capacity for kernel.
        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        if (smem_size >= (48 << 10))
        {
            cudaError_t result
                = cudaFuncSetAttribute(Kernel<GemmKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

            if (result != cudaSuccess)
            {
                return Status::kErrorInternal;
            }
        }

        return Status::kSuccess;
    }

    /// Lightweight update given a subset of arguments
    Status update(Arguments const& args, void* workspace = nullptr)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat()::update() - workspace: " << workspace);

        size_t workspace_bytes = get_workspace_size(args);

        if (workspace_bytes && !workspace)
        {
            return Status::kErrorWorkspaceNull;
        }

        params_.update(args, workspace);

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status run(cudaStream_t stream = nullptr)
    {
        CUTLASS_TRACE_HOST("GemmUniversalBaseCompat::run()");

        //
        // Configure grid and block dimensions
        //
        ThreadblockSwizzle threadblock_swizzle;

        dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);
        dim3 block(GemmKernel::kThreadCount, 1, 1);

        int smem_size = int(sizeof(typename GemmKernel::SharedStorage));

        //
        // Launch kernel
        //
        CUTLASS_TRACE_HOST("  grid: (" << grid << "), block: (" << block << "), SMEM: " << smem_size << " bytes");

        // Launch
        cutlass::Kernel<GemmKernel><<<grid, block, smem_size, stream>>>(params_);

        //
        // Query for errors
        //
        cudaError_t result = cudaGetLastError();

        if (result != cudaSuccess)
        {
            CUTLASS_TRACE_HOST("  grid launch failed with error " << cudaGetErrorString(result));
            return Status::kErrorInternal;
        }

        return Status::kSuccess;
    }

    /// Runs the kernel using initialized state.
    Status operator()(cudaStream_t stream = nullptr)
    {
        return run(stream);
    }

    /// Runs the kernel using initialized state.
    Status operator()(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
    {
        Status status = initialize(args, workspace, stream);

        if (status == Status::kSuccess)
        {
            status = run(stream);
        }

        return status;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
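The split-K fix-up in get_grid_shape_ above rounds each K partition up to a 128-bit-aligned element count and then recomputes the number of K partitions. A small host-side sketch of that arithmetic with hypothetical numbers and local stand-ins for cutlass's ceil_div/round_up/const_max helpers:

#include <algorithm>
#include <cstdio>
static int ceil_div(int a, int b) { return (a + b - 1) / b; }
static int round_up(int a, int b) { return ceil_div(a, b) * b; }
int main() {
    int const k = 4096, split_k_slices = 3;   // args.problem_size.k(), args.batch_count (hypothetical)
    int const bits_a = 16, bits_b = 16;       // hypothetical element widths
    int const kAlignK = std::max({128 / bits_a, 128 / bits_b, 1});
    int const gemm_k_size = round_up(ceil_div(k, split_k_slices), kAlignK);
    int const grid_k = ceil_div(k, gemm_k_size);
    std::printf("gemm_k_size=%d grid_k=%d\n", gemm_k_size, grid_k); // 1368, 3
    return 0;
}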
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/device/splitk_gemm_grouped.h
deleted
100644 → 0
View file @
9829e77e
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*!
\file
\brief Based on cutlass/include/cutlass/gemm/kernel/gemm_grouped.h
*/
#pragma once
#include <limits>
#include <numeric>
#include <vector>
#include "cutlass/arch/arch.h"
#include "cutlass/cutlass.h"
#include "cutlass/device_kernel.h"
#include "cutlass/numeric_types.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/gemm/kernel/gemm_universal.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/default_gemm_configuration.h"
#include "cutlass/gemm/kernel/default_gemm_universal.h"
#include "cutlass/trace.h"
////////////////////////////////////////////////////////////////////////////////
namespace cutlass
{
namespace gemm
{
namespace device
{

/////////////////////////////////////////////////////////////////////////////////////////////////

template <typename T_IN, typename T_OUT>
__global__ void splitkReduction(T_OUT** out_tensor, const T_IN* in_tensor, GemmCoord const* problem_sizes, int splitk,
    int64_t* splitk_buffer_offsets)
{
    // in_tensor: [problem_idx, k_partition, hidden_size]
    // Note that different requests of in_tensor might have different hidden_size (=m*n),
    // so we need to use splitk_buffer_offsets.
    // out_tensor: problem_idx * [hidden_size]

    int const problem_idx = blockIdx.y;
    GemmCoord problem = problem_sizes[problem_idx];
    int const hidden_size = problem.m() * problem.n();
    const T_IN* in_tensor_ = in_tensor + splitk_buffer_offsets[problem_idx] * splitk;
    T_OUT* out_tensor_ = out_tensor[problem_idx];

    for (int i = threadIdx.x + blockIdx.x * blockDim.x; i < hidden_size; i += blockDim.x * gridDim.x)
    {
        float sum = 0.0f;
        for (int k_idx = 0; k_idx < splitk; k_idx++)
        {
            sum += (float) in_tensor_[k_idx * hidden_size + i];
        }
        out_tensor_[i] = (T_OUT) (sum);
    }
}
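Because each problem contributes m*n partial-sum elements per split-K slice and m*n differs between problems, the reduction kernel locates a problem's block of partial sums through splitk_buffer_offsets. The sketch below shows one plausible host-side construction of those offsets as an exclusive prefix sum of per-problem m*n; the helper name and ProblemSize struct are hypothetical and only illustrate how the kernel above appears to consume the offsets (it multiplies them by splitk when indexing).

// Hypothetical host-side helper: exclusive prefix sum of per-problem m*n.
#include <cstdint>
#include <cstdio>
#include <vector>

struct ProblemSize
{
    int m;
    int n;
};

std::vector<int64_t> make_splitk_buffer_offsets(std::vector<ProblemSize> const& problems)
{
    std::vector<int64_t> offsets(problems.size());
    int64_t running = 0;
    for (size_t i = 0; i < problems.size(); ++i)
    {
        offsets[i] = running; // start of problem i, in elements (per split-K slice)
        running += int64_t(problems[i].m) * problems[i].n;
    }
    return offsets;
}

int main()
{
    std::vector<ProblemSize> problems = {{128, 256}, {64, 64}, {32, 512}};
    for (int64_t offset : make_splitk_buffer_offsets(problems))
    {
        std::printf("%lld\n", static_cast<long long>(offset)); // 0, 32768, 36864
    }
    return 0;
}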
/// GEMM Grouped
template <typename BaseKernel_>
class BaseSplitkGrouped
{
public:
    using BaseKernel = BaseKernel_;

    using ElementA = typename BaseKernel::ElementA;
    using LayoutA = typename BaseKernel::LayoutA;
    using TensorRefA = TensorRef<ElementA const, LayoutA>;
    static ComplexTransform const kTransformA = BaseKernel::kTransformA;
    static int const kAlignmentA = BaseKernel::kAlignmentA;

    using ElementB = typename BaseKernel::ElementB;
    using LayoutB = typename BaseKernel::LayoutB;
    using TensorRefB = TensorRef<ElementB const, LayoutB>;
    static ComplexTransform const kTransformB = BaseKernel::kTransformB;
    static int const kAlignmentB = BaseKernel::kAlignmentB;

    using ElementC = typename BaseKernel::ElementC;
    using LayoutC = typename BaseKernel::LayoutC;
    using TensorRefC = TensorRef<ElementC const, LayoutC>;
    using TensorRefD = TensorRef<ElementC, LayoutC>;
    static int const kAlignmentC = BaseKernel::kAlignmentC;

    using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC;

    using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp;
    using ThreadblockSwizzle = typename threadblock::GemmSplitKHorizontalThreadblockSwizzle;

    using Operator = typename BaseKernel::Operator;
    using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator;

    using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
    using MathOperator = typename WarpMmaOperator::MathOperator;
    using OperatorClass = typename WarpMmaOperator::OperatorClass;
    using ArchTag = typename WarpMmaOperator::ArchTag;
    using ThreadblockShape = typename BaseKernel::Mma::Shape;
    using WarpShape = typename BaseKernel::WarpShape;
    using InstructionShape = typename BaseKernel::InstructionShape;
    static int const kStages = BaseKernel::Mma::kStages;

    /// Argument structure
    using Arguments = typename BaseKernel::Arguments;

    using ProblemInfo = typename BaseKernel::ProblemVisitor::ProblemInfo;
protected:
    /// Kernel parameters object
    typename BaseKernel::Params gemm_params_;

private:
    /// Get the number of tiles across all problems in a group
    static int32_t group_tile_count(cutlass::gemm::GemmCoord const* problem_sizes_ptr, int problem_count)
    {
        int32_t tiles = 0;
        for (int32_t i = 0; i < problem_count; ++i)
        {
            cutlass::gemm::GemmCoord problem = problem_sizes_ptr[i];
            BaseKernel::ProblemVisitor::possibly_transpose_problem(problem);
            tiles += problem_tile_count(problem);
        }
        return tiles;
    }

    /// Copy from `data` to `workspace`
    Status copy_to_workspace(void* workspace, void* data, size_t bytes)
    {
        cudaError_t cuda_error = cudaMemcpy(workspace, data, bytes, cudaMemcpyHostToDevice);
        if (cuda_error != cudaSuccess)
        {
            // Call cudaGetLastError() to clear the error bit
            cuda_error = cudaGetLastError();
            CUTLASS_TRACE_HOST(" cudaMemcpy() returned error " << cudaGetErrorString(cuda_error));
            return Status::kErrorInternal;
        }

        return Status::kSuccess;
    }

    /// Precomputes scheduling information for the grouped GEMM
    Status precompute(Arguments const& args, int32_t tile_count, void* workspace)
    {
        size_t workspace_bytes = get_workspace_size(args);
        std::vector<uint8_t> host_workspace(workspace_bytes);
        BaseKernel::ProblemVisitor::host_precompute(
            args.host_problem_sizes, args.problem_count, args.threadblock_count, (void*) host_workspace.data());
        return copy_to_workspace(workspace, host_workspace.data(), workspace_bytes);
    }

    /// Reorder `data` according to `indices`
    template <typename T>
    static void reorder_array(T* data, std::vector<size_t> const& indices)
    {
        // For now, simply create a copy of the data and then copy over to the original.
        std::vector<T> copy(indices.size());
        for (size_t i = 0; i < indices.size(); ++i)
        {
            copy.at(i) = data[indices[i]];
        }

        memcpy(data, copy.data(), indices.size() * sizeof(T));
    }
public:
    /// Constructs the GEMM.
    BaseSplitkGrouped() {}

    /// Determines whether the GEMM can execute the given problem.
    static Status can_implement(Arguments const& args)
    {
        return BaseKernel::can_implement(args);
    }

    /// Get the number of tiles in a problem
    static int32_t problem_tile_count(cutlass::gemm::GemmCoord const& problem)
    {
        auto grid = BaseKernel::ProblemVisitor::grid_shape(problem);
        return BaseKernel::ProblemVisitor::tile_count(grid);
    }

    /// Get the number of tiles across all problems in a group
    static int32_t group_tile_count(Arguments const& args)
    {
        if (args.host_problem_sizes == nullptr)
        {
            CUTLASS_TRACE_HOST("Received nullptr for `args.host_problem_sizes`");
            return -1;
        }

        return group_tile_count(args.host_problem_sizes, args.problem_count);
    }

    /// Gets the workspace size
    static size_t get_workspace_size(Arguments const& args)
    {
        size_t total_mn = 0;
        for (int i = 0; i < args.problem_count; i++)
        {
            total_mn += args.host_problem_sizes[i].m() * args.host_problem_sizes[i].n();
        }
        size_t workSpaceSize = total_mn * sizeof(ElementAccumulator) * args.split_k_slices;

        if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
        {
            workSpaceSize += BaseKernel::ProblemVisitor::get_workspace_size(
                args.host_problem_sizes, args.problem_count, args.threadblock_count);
        }
        return workSpaceSize;
    }
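To make the sizing concrete: assuming a float accumulator and a ProblemVisitor that needs no precomputation, two problems of 128x256 and 64x64 with split_k_slices = 4 need (32768 + 4096) elements x 4 bytes x 4 slices = 589,824 bytes of split-K scratch. A tiny standalone check of that arithmetic:

// Worked example of the split-K scratch sizing above (float accumulator assumed;
// the optional ProblemVisitor precompute tail is ignored here).
#include <cstdio>

int main()
{
    struct { int m, n; } problems[] = {{128, 256}, {64, 64}};
    int split_k_slices = 4;

    size_t total_mn = 0;
    for (auto const& p : problems)
    {
        total_mn += size_t(p.m) * p.n; // 32768 + 4096 = 36864 elements
    }

    size_t workspace_bytes = total_mn * sizeof(float) * split_k_slices;
    std::printf("%zu bytes\n", workspace_bytes); // 589824
    return 0;
}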
    /// Computes the grid shape
    static dim3 get_grid_shape(Arguments const& args)
    {
        return dim3(args.threadblock_count, 1, 1);
    }

    /// Computes the maximum number of active blocks per multiprocessor
    static int maximum_active_blocks(int smem_capacity = -1)
    {
        CUTLASS_TRACE_HOST("BaseSplitkGrouped::maximum_active_blocks()");

        int smem_size = int(sizeof(typename BaseKernel::SharedStorage));

        CUTLASS_TRACE_HOST(" smem_size: " << smem_size << " bytes");

        cudaError_t result;
        if (smem_size > (48 << 10))
        {
            result = cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

            if (result != cudaSuccess)
            {
                // Call cudaGetLastError() to clear the error bit
                result = cudaGetLastError();
                CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error " << cudaGetErrorString(result));
                return -1;
            }
        }

        int max_active_blocks = -1;
        result = cudaOccupancyMaxActiveBlocksPerMultiprocessor(
            &max_active_blocks, Kernel<BaseKernel>, BaseKernel::kThreadCount, smem_size);

        if (result != cudaSuccess)
        {
            // Call cudaGetLastError() to clear the error bit
            result = cudaGetLastError();
            CUTLASS_TRACE_HOST(
                " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error " << cudaGetErrorString(result));
            return -1;
        }

        CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks);
        return max_active_blocks;
    }
    /// Sorts each pointer passed in according to the indices that sort
    /// `problem_sizes_ptr` in descending order of problem-K dimension.
    static void sort_problems(int problem_count, cutlass::gemm::GemmCoord* problem_sizes_ptr, int64_t* lda_host_ptr,
        int64_t* ldb_host_ptr, int64_t* ldc_host_ptr, int64_t* ldd_host_ptr, int64_t* offset_A_ptr,
        int64_t* offset_B_ptr, int64_t* offset_C_ptr, int64_t* offset_D_ptr)
    {
        std::vector<size_t> indices(problem_count);
        std::iota(indices.begin(), indices.end(), 0);
        std::stable_sort(indices.begin(), indices.end(),
            [&problem_sizes_ptr](size_t i, size_t j) { return problem_sizes_ptr[i].k() > problem_sizes_ptr[j].k(); });

        reorder_array(problem_sizes_ptr, indices);
        reorder_array(lda_host_ptr, indices);
        reorder_array(ldb_host_ptr, indices);
        reorder_array(ldc_host_ptr, indices);
        reorder_array(ldd_host_ptr, indices);
        reorder_array(offset_A_ptr, indices);
        reorder_array(offset_B_ptr, indices);
        reorder_array(offset_C_ptr, indices);
        reorder_array(offset_D_ptr, indices);
    }
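The index-permutation trick used by sort_problems() is easy to exercise in isolation: build one permutation that sorts by descending K, then gather every per-problem side array (leading dimensions, offsets, ...) through it. A minimal standalone sketch, with plain int vectors standing in for GemmCoord and the lda/offset arrays:

// Standalone illustration of the sort_problems() pattern: one index permutation
// sorted by descending K, applied to every per-problem side array.
#include <algorithm>
#include <cstdio>
#include <numeric>
#include <vector>

int main()
{
    std::vector<int> k = {64, 256, 128}; // per-problem K dimensions
    std::vector<int> lda = {10, 20, 30}; // a side array that must stay aligned with k

    std::vector<size_t> indices(k.size());
    std::iota(indices.begin(), indices.end(), 0);
    std::stable_sort(
        indices.begin(), indices.end(), [&](size_t i, size_t j) { return k[i] > k[j]; });

    auto gather = [&](std::vector<int>& data)
    {
        std::vector<int> copy(indices.size());
        for (size_t i = 0; i < indices.size(); ++i)
        {
            copy[i] = data[indices[i]];
        }
        data = copy;
    };
    gather(k);
    gather(lda);

    for (size_t i = 0; i < k.size(); ++i)
    {
        std::printf("k=%d lda=%d\n", k[i], lda[i]); // 256/20, 128/30, 64/10
    }
    return 0;
}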
    /// Computes the number of threadblocks to launch for the grouped kernel
    static int sufficient(
        cutlass::gemm::GemmCoord const* problem_sizes_ptr = nullptr, int problem_count = 0, int available_sm_count = -1)
    {
        // Determine the number of blocks that would be launched to fill up a single
        // wave on the GPU with each SM having maximum occupancy.
        int device_idx;
        cudaError_t result = cudaGetDevice(&device_idx);
        if (result != cudaSuccess)
        {
            // Call cudaGetLastError() to clear the error bit
            result = cudaGetLastError();
            CUTLASS_TRACE_HOST(" cudaGetDevice() returned error " << cudaGetErrorString(result));
            return 0;
        }

        int multiprocessor_count;
        result = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device_idx);
        if (result != cudaSuccess)
        {
            CUTLASS_TRACE_HOST(" cudaDeviceGetAttribute() returned error " << cudaGetErrorString(result));
            return 0;
        }

        bool override_sm_count = (available_sm_count < 0 || available_sm_count > multiprocessor_count);
        if (override_sm_count)
        {
            available_sm_count = multiprocessor_count;
        }

        int max_active_blocks = maximum_active_blocks();
        if (max_active_blocks <= 0)
        {
            return 0;
        }

        int occupancy_based_block_count = available_sm_count * max_active_blocks;

        if (problem_sizes_ptr == nullptr || problem_count == 0)
        {
            return occupancy_based_block_count;
        }

        int total_tiles = group_tile_count(problem_sizes_ptr, problem_count);

        // If the group contains a single problem, launching the exact number of
        // threadblocks needed to cover the problem minimizes the work performed
        // per threadblock in finding the next tile to compute. We return total_tiles
        // unless the user has provided the SM count.
        if (problem_count == 1 && override_sm_count)
        {
            return total_tiles;
        }

        // Choose between the full wave of threadblocks and the tile count. If there
        // are fewer tiles in the group than threadblocks in the full wave, only
        // some threadblocks will be assigned tiles. Those threadblocks
        // which are not assigned tiles still need to perform the work of iterating through
        // problem sizes to determine that they have no work to do. This competes for cycles
        // with those threadblocks that are assigned tiles to compute.
        return std::min(total_tiles, occupancy_based_block_count);
    }
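Numerically: with 108 SMs (an A100-class device) and two resident threadblocks per SM, a full wave is 216 threadblocks, so a group whose problems add up to only 150 tiles launches min(150, 216) = 150 blocks. A tiny sketch of that arithmetic, with illustrative device numbers:

// Illustrative arithmetic behind sufficient(): full-wave block count vs. tile count.
#include <algorithm>
#include <cstdio>

int main()
{
    int multiprocessor_count = 108; // e.g. an A100-class GPU
    int max_active_blocks = 2;      // from the occupancy query above
    int total_tiles = 150;          // tiles across the whole group

    int occupancy_based_block_count = multiprocessor_count * max_active_blocks; // 216
    std::printf("%d\n", std::min(total_tiles, occupancy_based_block_count));    // 150
    return 0;
}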
    /// Initializes GEMM state from arguments.
    Status initialize(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr)
    {
        CUTLASS_TRACE_HOST("BaseSplitkGrouped::initialize() - workspace "
            << workspace << ", stream: " << (stream ? "non-null" : "null"));

        // Workspace
        size_t workspace_bytes = get_workspace_size(args);

        if (workspace_bytes && !workspace)
        {
            return Status::kErrorWorkspaceNull;
        }

        if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
        {
            int32_t tile_count = group_tile_count(args);
            Status status = precompute(args, tile_count, workspace);
            if (status != Status::kSuccess)
            {
                return status;
            }

            gemm_params_ = typename BaseKernel::Params(args, workspace, tile_count);
        }
        else
        {
            gemm_params_ = typename BaseKernel::Params(args, workspace);
        }

        // Specify shared memory capacity for kernel.
        int smem_size = int(sizeof(typename BaseKernel::SharedStorage));

        if (smem_size >= (48 << 10))
        {
            cudaError_t result
                = cudaFuncSetAttribute(Kernel<BaseKernel>, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);

            if (result != cudaSuccess)
            {
                return Status::kErrorInternal;
            }
        }

        return Status::kSuccess;
    }
    /// Lightweight update given a subset of arguments
    Status update(Arguments const& args, void* workspace = nullptr)
    {
        size_t workspace_bytes = get_workspace_size(args);

        if (workspace_bytes && !workspace)
        {
            return Status::kErrorWorkspaceNull;
        }

        if (BaseKernel::ProblemVisitor::kRequiresPrecomputation)
        {
            int32_t tile_count = group_tile_count(args);
            Status status = precompute(args, tile_count, workspace);
            if (status != Status::kSuccess)
            {
                return status;
            }

            gemm_params_.update(args, workspace, tile_count);
        }
        else
        {
            gemm_params_.update(args, workspace);
        }

        return Status::kSuccess;
    }
    /// Runs the kernel using initialized state.
    Status run(cudaStream_t stream = nullptr)
    {
        if (!gemm_params_.problem_visitor.problem_count)
        {
            return Status::kSuccess;
        }

        //
        // Launch kernel
        //

        // Launch splitk grouped gemm
        {
            dim3 grid(gemm_params_.threadblock_count, 1, gemm_params_.split_k_slices);
            dim3 block(BaseKernel::kThreadCount, 1, 1);

            int smem_size = int(sizeof(typename BaseKernel::SharedStorage));
            cutlass::Kernel<BaseKernel><<<grid, block, smem_size, stream>>>(gemm_params_);

            cudaError_t result = cudaGetLastError();

            if (result != cudaSuccess)
            {
                CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
                return Status::kErrorInternal;
            }
        }

        // Launch splitkReduction
        {
            dim3 grid(32, gemm_params_.problem_visitor.problem_count);
            dim3 block(256);
            splitkReduction<<<grid, block, 0, stream>>>(gemm_params_.ptr_D, gemm_params_.ptr_D_split,
                gemm_params_.problem_visitor.problem_sizes, gemm_params_.split_k_slices,
                gemm_params_.splitk_buffer_offsets);

            cudaError_t result = cudaGetLastError();

            if (result != cudaSuccess)
            {
                CUTLASS_TRACE_HOST(" grid launch failed with error " << cudaGetErrorString(result));
                return Status::kErrorInternal;
            }
        }

        return Status::kSuccess;
    }
    /// Runs the kernel using initialized state.
    Status operator()(cudaStream_t stream = nullptr)
    {
        return run(stream);
    }

    /// Initializes and runs the kernel.
    Status operator()(Arguments const& args, void* workspace, cudaStream_t stream = nullptr)
    {
        Status status = initialize(args, workspace, stream);

        if (status == Status::kSuccess)
        {
            status = run(stream);
        }

        return status;
    }
};

/////////////////////////////////////////////////////////////////////////////////////////////////

/// GEMM Grouped
template <typename GemmKernel_>
class SplitkGemmGrouped : public BaseSplitkGrouped<GemmKernel_>
{
public:
    using GemmKernel = GemmKernel_;
};

/////////////////////////////////////////////////////////////////////////////////////////////////

} // namespace device
} // namespace gemm
} // namespace cutlass

/////////////////////////////////////////////////////////////////////////////////////////////////
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_fpA_intB_traits.h
deleted
100644 → 0
View file @
9829e77e
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/bfloat16.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/half.h"
#include "cutlass/layout/matrix.h"
#include "cutlass_extensions/arch/mma.h"
#include "cutlass_extensions/gemm/kernel/mixed_gemm_B_layout.h"
namespace cutlass
{
namespace gemm
{
namespace kernel
{

template <typename TypeA, typename TypeB, typename arch, typename Enable = void>
struct MixedGemmArchTraits
{
    static_assert(dependent_false<arch>, "Unrecognised parameterization");
};

template <typename Arch>
struct MixedGemmArchTraits<float, float, Arch>
{
    static constexpr int Stages = 2;
    using OperatorClass = cutlass::arch::OpClassSimt;
    using AccType = float;
    using LayoutB = cutlass::layout::ColumnMajor;

    static constexpr int ElementsPerAccessA = 1;
    static constexpr int ElementsPerAccessB = 1;
    static constexpr int ElementsPerAccessC = 1;
    static constexpr int ThreadblockK = 8;
    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
    using Operator = cutlass::arch::OpMultiplyAdd;
};

// ======================= Turing Traits ==============================
// Note that Turing has no native bfloat16 support, so weights and activations are cast to fp16,
// compute happens in fp16, and the result is then converted to bf16 output.
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm75,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
        || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
    using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm75>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    using LayoutB = typename LayoutDetails::Layout;

    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>;

    using Operator = typename LayoutDetails::Operator;
};

// ======================= Ampere Traits ==============================
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm80,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
        || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
    using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm80>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    using LayoutB = typename LayoutDetails::Layout;

    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

    using Operator = typename LayoutDetails::Operator;
};

// ======================= Ada Traits ==============================
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::half_t>::value
        || cutlass::platform::is_same<TypeA, cutlass::bfloat16_t>::value>::type>
{
private:
    using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    using LayoutB = typename LayoutDetails::Layout;

    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeA>::value;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;

    using Operator = typename LayoutDetails::Operator;
};

// FP8: A/B = fp8, C/D = fp32
template <typename TypeA, typename TypeB>
struct MixedGemmArchTraits<TypeA, TypeB, cutlass::arch::Sm89,
    typename cutlass::platform::enable_if<cutlass::platform::is_same<TypeA, cutlass::float_e4m3_t>::value
        || cutlass::platform::is_same<TypeA, cutlass::float_e5m2_t>::value>::type>
{
private:
    using LayoutDetails = LayoutDetailsB<TypeA, TypeB, cutlass::arch::Sm89>;

public:
    static constexpr int ThreadblockK = LayoutDetails::ThreadblockK;

    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using AccType = float;
    // Be careful: TypeC should align with HopperGroupedGemmInput::OutputTypeAdaptor_t<TypeA>.
    using TypeC = __nv_bfloat16;
    using LayoutB = typename LayoutDetails::Layout;

    static constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits<TypeA>::value;
    static constexpr int ElementsPerAccessB = LayoutDetails::ElementsPerAccess;
    static constexpr int ElementsPerAccessC = 128 / cutlass::sizeof_bits<TypeC>::value;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 256 / cutlass::sizeof_bits<TypeA>::value>;

    using Operator = typename LayoutDetails::Operator;
};

} // namespace kernel
} // namespace gemm
} // namespace cutlass
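The repeated 128 / sizeof_bits<T> and 256 / sizeof_bits<T> expressions in these traits are just vector-width and instruction-K arithmetic. A quick standalone check of the values they produce for half/bf16 and fp8 inputs:

// Quick check of the alignment math used by the traits above: access width is
// 128 bits / element width, and the Sm89 instruction K is 256 / element width.
#include <cstdio>

int main()
{
    int bits_half = 16; // cutlass::half_t / bfloat16_t
    int bits_fp8 = 8;   // cutlass::float_e4m3_t / float_e5m2_t

    std::printf("ElementsPerAccessA (half): %d\n", 128 / bits_half);      // 8
    std::printf("ElementsPerAccessA (fp8):  %d\n", 128 / bits_fp8);       // 16
    std::printf("Sm89 InstructionShape K (half): %d\n", 256 / bits_half); // 16
    std::printf("Sm89 InstructionShape K (fp8):  %d\n", 256 / bits_fp8);  // 32
    return 0;
}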
sgl-kernel/3rdparty/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/kernel/default_int8_traits.h
deleted
100644 → 0
View file @
9829e77e
/*
* SPDX-FileCopyrightText: Copyright (c) 2022-2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: Apache-2.0
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#pragma once
#include "cutlass/arch/arch.h"
#include "cutlass/arch/mma.h"
#include "cutlass/cutlass.h"
#include "cutlass/gemm/gemm.h"
#include "cutlass/layout/matrix.h"
namespace cutlass
{
namespace gemm
{
namespace kernel
{

template <typename arch>
struct Int8GemmArchTraits
{
    using OperatorClass = cutlass::arch::OpClassSimt;
    using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>;
};

// ======================= Turing Traits ==============================
template <>
struct Int8GemmArchTraits<cutlass::arch::Sm75>
{
    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using InstructionShape = cutlass::gemm::GemmShape<8, 8, 16>;
};

// ======================= Ampere Traits ==============================
template <>
struct Int8GemmArchTraits<cutlass::arch::Sm80>
{
    using OperatorClass = cutlass::arch::OpClassTensorOp;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>;
};

} // namespace kernel
} // namespace gemm
} // namespace cutlass