Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
dec32dc6
Commit
dec32dc6
authored
Jan 31, 2025
by
ThomasNing
Browse files
Finish the feature and merge with develop on the computeV2
parents
71352c44
c5fff071
Changes
215
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
936 additions
and
595 deletions
+936
-595
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+4
-15
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+4
-110
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+35
-18
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+398
-289
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+25
-0
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
...ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
+19
-19
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
+4
-4
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
...layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
+13
-13
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
+3
-3
include/ck_tile/ops/rmsnorm2d.hpp
include/ck_tile/ops/rmsnorm2d.hpp
+1
-0
include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp
...ude/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp
+169
-28
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp
...norm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp
+5
-5
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
...ps/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
+66
-15
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp
...ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp
+13
-13
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
...ps/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
+73
-18
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp
...e/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp
+54
-0
include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp
...ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp
+15
-13
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
...ude/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
+15
-15
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp
...othquant/pipeline/smoothquant_pipeline_default_policy.hpp
+2
-2
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp
...ps/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp
+18
-15
No files found.
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -39,17 +39,6 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -39,17 +39,6 @@ struct GemmPipelineAGmemBGmemCRegV1
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
static
constexpr
bool
kPadK
=
Problem
::
kPadK
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
{
return
integer_divide_ceil
(
sizeof
(
ADataType
)
*
Policy
::
template
MakeALdsBlockDescriptor
<
Problem
>().
get_element_space_size
(),
16
)
*
16
+
sizeof
(
BDataType
)
*
Policy
::
template
MakeBLdsBlockDescriptor
<
Problem
>().
get_element_space_size
();
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
{
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
...
@@ -150,7 +139,7 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -150,7 +139,7 @@ struct GemmPipelineAGmemBGmemCRegV1
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegBlockD
escriptor
<
Problem
>());
Policy
::
template
MakeShuffledARegBlockD
istribution
<
Problem
>());
shuffle_tile
(
a_shuffle_tmp
,
a_block_tile
);
shuffle_tile
(
a_shuffle_tmp
,
a_block_tile
);
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_shuffle_tmp
);
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_shuffle_tmp
);
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
...
@@ -164,7 +153,7 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -164,7 +153,7 @@ struct GemmPipelineAGmemBGmemCRegV1
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegBlockD
escriptor
<
Problem
>());
Policy
::
template
MakeShuffledBRegBlockD
istribution
<
Problem
>());
shuffle_tile
(
b_shuffle_tmp
,
b_block_tile
);
shuffle_tile
(
b_shuffle_tmp
,
b_block_tile
);
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_shuffle_tmp
);
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_shuffle_tmp
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
...
@@ -201,7 +190,7 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -201,7 +190,7 @@ struct GemmPipelineAGmemBGmemCRegV1
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
auto
b_shuffle_tmp_loop
=
make_static_distributed_tensor
<
BDataType
>
(
auto
b_shuffle_tmp_loop
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegBlockD
escriptor
<
Problem
>());
Policy
::
template
MakeShuffledBRegBlockD
istribution
<
Problem
>());
shuffle_tile
(
b_shuffle_tmp_loop
,
b_block_tile
);
shuffle_tile
(
b_shuffle_tmp_loop
,
b_block_tile
);
store_tile
(
b_copy_lds_window
,
store_tile
(
b_copy_lds_window
,
tile_elementwise_in
(
b_element_func
,
b_shuffle_tmp_loop
));
tile_elementwise_in
(
b_element_func
,
b_shuffle_tmp_loop
));
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
dec32dc6
...
@@ -18,37 +18,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -18,37 +18,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static
constexpr
bool
TransposeC
=
true
;
static
constexpr
bool
TransposeC
=
true
;
#if 0
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), number<32>{});
return a_lds_block_desc;
}
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{});
return b_lds_block_desc;
}
#elif
1
// 3d + padding
// 3d + padding
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
...
@@ -58,7 +27,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -58,7 +27,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
// TODO: this 8 is AK1! should be a policy parameter!
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
constexpr
auto
a_lds_block_desc_0
=
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
kMPerBlock
>
{},
number
<
8
>
{}),
make_tuple
(
number
<
kKPerBlock
/
8
>
{},
number
<
kMPerBlock
>
{},
number
<
8
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
make_tuple
(
number
<
(
kMPerBlock
+
1
)
*
8
>
{},
number
<
8
>
{},
number
<
1
>
{}),
...
@@ -127,87 +95,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -127,87 +95,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
{
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
return
Problem
::
VectorLoadSize
;
return
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackB
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackB
()
{
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
return
Problem
::
VectorLoadSize
;
return
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
}
}
#elif 1
// fake XOR
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
using
namespace
ck_tile
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
a_lds_block_desc_d1_d2_d3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
kMPerBlock
/
2
>
{},
number
<
2
>
{},
number
<
kKPerBlock
>
{}),
number
<
kKPerBlock
>
{});
constexpr
index_t
kK1
=
16
/
sizeof
(
ADataType
);
constexpr
auto
a_lds_block_desc_d4_d5_d6
=
transform_tensor_descriptor
(
a_lds_block_desc_d1_d2_d3
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
kMPerBlock
/
2
>
{},
number
<
kKPerBlock
>
{}),
kK1
),
make_pass_through_transform
(
2
)),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}));
constexpr
auto
a_lds_block_desc_m_k
=
transform_tensor_descriptor
(
a_lds_block_desc_d4_d5_d6
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
kMPerBlock
/
2
>
{},
number
<
2
>
{})),
make_pass_through_transform
(
kKPerBlock
)),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
a_lds_block_desc_m_k
;
}
// fake XOR
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLdsBlockDescriptor
()
{
using
namespace
ck_tile
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
auto
b_lds_block_desc_d1_d2_d3
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
kNPerBlock
/
2
>
{},
number
<
2
>
{},
number
<
kKPerBlock
>
{}),
number
<
kKPerBlock
>
{});
constexpr
index_t
kK1
=
16
/
sizeof
(
BDataType
);
constexpr
auto
b_lds_block_desc_d4_d5_d6
=
transform_tensor_descriptor
(
b_lds_block_desc_d1_d2_d3
,
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
kNPerBlock
/
2
>
{},
number
<
kKPerBlock
>
{}),
kK1
),
make_pass_through_transform
(
2
)),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{}));
constexpr
auto
b_lds_block_desc_n_k
=
transform_tensor_descriptor
(
b_lds_block_desc_d4_d5_d6
,
make_tuple
(
make_merge_transform
(
make_tuple
(
number
<
kNPerBlock
/
2
>
{},
number
<
2
>
{})),
make_pass_through_transform
(
kKPerBlock
)),
make_tuple
(
sequence
<
0
,
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
b_lds_block_desc_n_k
;
}
#endif
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
...
@@ -273,7 +168,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -273,7 +168,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static_assert
(
M0
*
M1
*
M2
==
MPerBlock
,
static_assert
(
M0
*
M1
*
M2
==
MPerBlock
,
"Incorrect M0, M2, M1 configuration! "
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!"
);
"M0, M1, M2 must cover whole MPerBlock!"
);
return
make_static_tile_distribution
(
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
...
@@ -394,7 +288,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -394,7 +288,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBRegBlockD
escriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBRegBlockD
istribution
()
{
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
...
@@ -442,7 +336,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -442,7 +336,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledARegBlockD
escriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledARegBlockD
istribution
()
{
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
dec32dc6
...
@@ -3,6 +3,7 @@
...
@@ -3,6 +3,7 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -11,10 +12,10 @@ template <typename ADataType_,
...
@@ -11,10 +12,10 @@ template <typename ADataType_,
typename
BDataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
BlockGemmShape_
,
typename
TileGemm
Traits_
>
typename
Traits_
>
struct
GemmPipelineProblemBase
struct
GemmPipelineProblemBase
{
{
using
Gemm
Traits
=
remove_cvref_t
<
TileGemm
Traits_
>
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
...
@@ -22,21 +23,21 @@ struct GemmPipelineProblemBase
...
@@ -22,21 +23,21 @@ struct GemmPipelineProblemBase
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
ALayout
=
remove_cvref_t
<
typename
Gemm
Traits
::
ALayout
>
;
using
ALayout
=
remove_cvref_t
<
typename
Traits
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Gemm
Traits
::
BLayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Traits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
Gemm
Traits
::
CLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
Traits
::
CLayout
>
;
static
constexpr
index_t
VectorLoadSize
=
GemmTraits
::
_VectorSize
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadM
=
Gemm
Traits
::
kPadM
;
static
constexpr
bool
kPadM
=
Traits
::
kPadM
;
static
constexpr
bool
kPadN
=
Gemm
Traits
::
kPadN
;
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadK
=
Gemm
Traits
::
kPadK
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
bool
isDoubleSmemBuffer
=
GemmTraits
::
isDoubleSmemBuffer
;
static
constexpr
bool
isDoubleSmemBuffer
=
GemmTraits
::
isDoubleSmemBuffer
;
static
constexpr
auto
Scheduler
=
GemmPipelineScheduler
::
Default
;
static
constexpr
auto
Scheduler
=
GemmPipelineScheduler
::
Default
;
static
constexpr
index_t
VectorLoadSize
=
Traits
::
_VectorSize
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentA
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentA
()
{
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
...
@@ -130,27 +131,43 @@ template <typename ADataType_,
...
@@ -130,27 +131,43 @@ template <typename ADataType_,
typename
BDataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
BlockGemmShape_
,
typename
TileGemm
Traits_
>
typename
Traits_
>
using
GemmPipelineProblem
=
using
GemmPipelineProblem
=
GemmPipelineProblemBase
<
ADataType_
,
BDataType_
,
CDataType_
,
BlockGemmShape_
,
TileGemm
Traits_
>
;
GemmPipelineProblemBase
<
ADataType_
,
BDataType_
,
CDataType_
,
BlockGemmShape_
,
Traits_
>
;
template
<
typename
ADataType_
,
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
BlockGemmShape_
,
typename
TileGemm
Traits_
,
typename
Traits_
,
GemmPipelineScheduler
Scheduler_
=
GemmPipelineScheduler
::
Intrawave
,
GemmPipelineScheduler
Scheduler_
=
GemmPipelineScheduler
::
Intrawave
,
bool
HasHotLoop_
=
true
,
bool
HasHotLoop_
=
true
,
TailNumber
TailNum_
=
TailNumber
::
Full
>
TailNumber
TailNum_
=
TailNumber
::
Full
>
struct
UniversalGemmPipelineProblem
:
public
GemmPipelineProblemBase
<
ADataType_
,
struct
UniversalGemmPipelineProblem
BDataType_
,
CDataType_
,
BlockGemmShape_
,
TileGemmTraits_
>
{
{
using
Traits
=
remove_cvref_t
<
Traits_
>
;
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
ALayout
=
remove_cvref_t
<
typename
Traits
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Traits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
Traits
::
CLayout
>
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadM
=
Traits
::
kPadM
;
static
constexpr
bool
kPadN
=
Traits
::
kPadN
;
static
constexpr
bool
kPadK
=
Traits
::
kPadK
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
static
constexpr
auto
TailNum
=
TailNum_
;
static
constexpr
auto
TailNum
=
TailNum_
;
static
constexpr
bool
TransposeC
=
Traits
::
TransposeC
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2024
-2025
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -15,30 +16,43 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -15,30 +16,43 @@ struct UniversalGemmPipelineAgBgCrPolicy
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
auto
I2
=
number
<
2
>
{};
static
constexpr
bool
TransposeC
=
true
;
static
constexpr
auto
ATileAccessPattern
=
tile_distribution_pattern
::
thread_raked
;
static
constexpr
auto
BTileAccessPattern
=
tile_distribution_pattern
::
thread_raked
;
template
<
typename
Problem
,
typename
DataType
,
index_t
MNPerBlock
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorLoadSize
()
/**
* @brief Get the maximum global memory vector load size.
*
* @tparam Problem The UniversalGemmPipelineProblem object.
* @tparam DataType The tensor data type we're considering.
* @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
* @tparam XPerTile The contiguous Tile dimension size.
* @return Maximum DRAM vector load size.
*/
template
<
typename
Problem
,
typename
DataType
,
index_t
MNPerBlock
,
index_t
XPerTile
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetGlobalVectorLoadSize
()
{
{
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
elements_per_thread
=
MNPerBlock
*
KPerBlock
/
BlockSize
;
constexpr
index_t
elements_per_thread
=
MNPerBlock
*
KPerBlock
/
BlockSize
;
if
constexpr
(
elements_per_thread
%
(
16
/
sizeof
(
DataType
))
==
0
)
// Assume DataType is even!
if
constexpr
(
XPerTile
%
(
16
/
sizeof
(
DataType
))
==
0
&&
elements_per_thread
%
(
16
/
sizeof
(
DataType
))
==
0
)
{
{
return
(
16
/
sizeof
(
DataType
));
return
(
16
/
sizeof
(
DataType
));
}
}
else
if
constexpr
(
elements_per_thread
%
(
8
/
sizeof
(
DataType
))
==
0
)
else
if
constexpr
(
XPerTile
%
(
8
/
sizeof
(
DataType
))
==
0
&&
elements_per_thread
%
(
8
/
sizeof
(
DataType
))
==
0
)
{
{
return
(
8
/
sizeof
(
DataType
));
return
(
8
/
sizeof
(
DataType
));
}
}
else
if
constexpr
(
elements_per_thread
%
(
4
/
sizeof
(
DataType
))
==
0
&&
else
if
constexpr
(
sizeof
(
DataType
)
>=
4
&&
XPerTile
%
(
4
/
sizeof
(
DataType
))
==
0
&&
sizeof
(
DataType
)
>
=
4
)
elements_per_thread
%
(
4
/
sizeof
(
DataType
)
)
=
=
0
)
{
{
return
(
4
/
sizeof
(
DataType
));
return
(
4
/
sizeof
(
DataType
));
}
}
else
if
constexpr
(
elements_per_thread
%
(
2
/
sizeof
(
DataType
))
==
0
&&
else
if
constexpr
(
sizeof
(
DataType
)
>=
2
&&
XPerTile
%
(
2
/
sizeof
(
DataType
))
==
0
&&
sizeof
(
DataType
)
>
=
2
)
elements_per_thread
%
(
2
/
sizeof
(
DataType
)
)
=
=
0
)
{
{
return
(
2
/
sizeof
(
DataType
));
return
(
2
/
sizeof
(
DataType
));
}
}
...
@@ -48,6 +62,126 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -48,6 +62,126 @@ struct UniversalGemmPipelineAgBgCrPolicy
}
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeA
()
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
GetGlobalVectorLoadSize
<
Problem
,
ADataType
,
MPerBlock
,
KPerBlock
>
();
}
else
{
return
GetGlobalVectorLoadSize
<
Problem
,
ADataType
,
MPerBlock
,
MPerBlock
>
();
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeB
()
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
GetGlobalVectorLoadSize
<
Problem
,
BDataType
,
NPerBlock
,
NPerBlock
>
();
}
else
{
return
GetGlobalVectorLoadSize
<
Problem
,
BDataType
,
NPerBlock
,
KPerBlock
>
();
}
}
/**
* @brief Get the vector store size for C tensor.
*
* @tparam Problem - Gemm pipeline problem class.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeC
()
{
using
BlockGemm
=
remove_cvref_t
<
decltype
(
GetBlockGemm
<
Problem
>
())
>
;
using
WG
=
typename
BlockGemm
::
WarpGemm
;
constexpr
bool
TransposeC
=
Problem
::
TransposeC
;
using
CLayout
=
typename
Problem
::
CLayout
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
// N is contiguous dimension
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
if
constexpr
(
TransposeC
)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr
index_t
NDimY
=
CWarpDstr
::
NDimY
;
constexpr
auto
c_warp_y_lengths
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
();
static_assert
(
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
==
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{}));
return
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{});
}
else
{
// In this case each thread has just a single item in Ndim
return
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
/
WG
::
kN
;
}
}
// M is contiguous dimension
else
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
if
constexpr
(
TransposeC
)
{
// In this case each thread has just a single item in Mdim
return
WG
::
WarpGemmAttribute
::
Impl
::
kCNLane
/
WG
::
kN
;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
constexpr
index_t
NDimY
=
CWarpDstr
::
NDimY
;
constexpr
auto
c_warp_y_lengths
=
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
();
static_assert
(
WG
::
WarpGemmAttribute
::
Impl
::
kCM1PerLane
==
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{}));
return
c_warp_y_lengths
.
get
(
number
<
NDimY
-
1
>
{});
}
}
else
{
static_assert
(
false
,
"Unsupported CLayout!"
);
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
{
using
BlockGemm
=
decltype
(
GetBlockGemm
<
Problem
>
());
constexpr
index_t
KPack
=
BlockGemm
::
Traits
::
KPack
;
return
KPack
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackB
()
{
using
BlockGemm
=
decltype
(
GetBlockGemm
<
Problem
>
());
constexpr
index_t
KPack
=
BlockGemm
::
Traits
::
KPack
;
return
KPack
;
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeALdsBlockDescriptor
()
{
{
...
@@ -56,7 +190,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -56,7 +190,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPack
=
Get
VectorLoadSize
<
Problem
,
ADataType
,
MPerBlock
>
();
constexpr
index_t
KPack
=
Get
SmemPackA
<
Problem
>
();
constexpr
auto
DataTypeSize
=
sizeof
(
ADataType
);
constexpr
auto
DataTypeSize
=
sizeof
(
ADataType
);
constexpr
auto
MLdsLayer
=
constexpr
auto
MLdsLayer
=
...
@@ -99,54 +233,193 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -99,54 +233,193 @@ struct UniversalGemmPipelineAgBgCrPolicy
return
a_lds_block_desc
;
return
a_lds_block_desc
;
}
}
/**
* @brief Create LDS block descriptor for B tensor.
*
* @tparam Problem Gemm pipeline problem.
* @return B tensor LDS block descriptor.
*/
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLdsBlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBLdsBlockDescriptor
()
{
{
// using BLayout = remove_cvref_t<typename Problem::BLayout>;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPack
=
GetVectorLoadSize
<
Problem
,
BDataType
,
NPerBlock
>
();
constexpr
auto
DataTypeSize
=
sizeof
(
BDataType
);
#if 1
constexpr
auto
NLdsLayer
=
// if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
(
32
*
4
/
KPerBlock
/
DataTypeSize
)
<
1
?
1
:
(
32
*
4
/
KPerBlock
/
DataTypeSize
);
{
constexpr
index_t
KPack
=
GetSmemPackB
<
Problem
>
();
constexpr
auto
b_lds_block_desc_0
=
make_naive_tensor_descriptor
(
constexpr
auto
BK0
=
number
<
KPerBlock
/
KPack
>
{};
make_tuple
(
number
<
KPerBlock
/
KPack
*
NLdsLayer
>
{},
constexpr
auto
DataTypeSize
=
sizeof
(
BDataType
);
number
<
NPerBlock
/
NLdsLayer
>
{},
constexpr
auto
NLdsLayer
=
number
<
KPack
>
{}),
(
32
*
4
/
KPerBlock
/
DataTypeSize
)
<
1
?
1
:
(
32
*
4
/
KPerBlock
/
DataTypeSize
);
make_tuple
(
number
<
KPack
>
{},
number
<
KPerBlock
*
NLdsLayer
>
{},
number
<
1
>
{}),
number
<
KPack
>
{},
constexpr
auto
b_lds_block_desc_0
=
make_naive_tensor_descriptor
(
number
<
1
>
{});
make_tuple
(
BK0
*
number
<
NLdsLayer
>
{},
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
KPack
>
{}),
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
make_tuple
(
number
<
KPack
>
{},
number
<
KPerBlock
*
NLdsLayer
>
{},
number
<
1
>
{}),
b_lds_block_desc_0
,
number
<
KPack
>
{},
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
1
>
{});
number
<
KPerBlock
/
KPack
*
NLdsLayer
>
{})),
make_pass_through_transform
(
number
<
KPack
>
{})),
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
b_lds_block_desc_0
,
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
make_tuple
(
make_xor_transform
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
BK0
*
number
<
NLdsLayer
>
{})),
constexpr
auto
b_lds_block_desc_xk0_mnldslayer_mn_xk1
=
transform_tensor_descriptor
(
make_pass_through_transform
(
number
<
KPack
>
{})),
b_lds_block_desc_permuted
,
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
sequence
<
1
,
0
>
{},
sequence
<
2
>
{}));
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
NLdsLayer
>
{})),
make_pass_through_transform
(
number
<
NPerBlock
/
NLdsLayer
>
{}),
constexpr
auto
b_lds_block_desc_bk0_nldslayer_n_bk1
=
transform_tensor_descriptor
(
make_pass_through_transform
(
number
<
KPack
>
{})),
b_lds_block_desc_permuted
,
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
make_tuple
(
make_unmerge_transform
(
make_tuple
(
BK0
,
number
<
NLdsLayer
>
{})),
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
make_pass_through_transform
(
number
<
NPerBlock
/
NLdsLayer
>
{}),
make_pass_through_transform
(
number
<
KPack
>
{})),
constexpr
auto
b_lds_block_desc
=
transform_tensor_descriptor
(
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{}),
b_lds_block_desc_xk0_mnldslayer_mn_xk1
,
make_tuple
(
sequence
<
0
,
2
>
{},
sequence
<
1
>
{},
sequence
<
3
>
{}));
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
NLdsLayer
>
{})),
constexpr
auto
b_lds_block_desc
=
transform_tensor_descriptor
(
make_merge_transform_v3_division_mod
(
b_lds_block_desc_bk0_nldslayer_n_bk1
,
make_tuple
(
number
<
KPerBlock
/
KPack
>
{},
number
<
KPack
>
{}))),
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
number
<
NPerBlock
/
NLdsLayer
>
{},
number
<
NLdsLayer
>
{})),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
make_merge_transform_v3_division_mod
(
make_tuple
(
BK0
,
number
<
KPack
>
{}))),
return
b_lds_block_desc
;
make_tuple
(
sequence
<
1
,
2
>
{},
sequence
<
0
,
3
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{}));
return
b_lds_block_desc
;
}
#else
else
// B is Row Major
{
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeB
<
Problem
>
();
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
KPerBlock
,
NPerBlock
,
VecLoadSize
,
BTileAccessPattern
>
;
constexpr
auto
BK0
=
number
<
TileEncodingPattern
::
X1
>
{};
constexpr
auto
BK1
=
number
<
TileEncodingPattern
::
Y0
>
{};
// constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
constexpr
auto
N0
=
TileEncodingPattern
::
X0
;
constexpr
auto
N1
=
NPerBlock
/
N0
;
using
WarpTile
=
typename
Problem
::
BlockGemmShape
::
WarpTile
;
constexpr
auto
NPerXdl
=
number
<
WarpTile
::
at
(
I1
)
>
{};
// constexpr auto KThreadWrite =
// BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
constexpr
auto
KThreadWrite
=
TileEncodingPattern
::
Y2
;
constexpr
auto
K0PerThreadWrite
=
BK0
/
KThreadWrite
;
constexpr
auto
KThreadRead
=
64
/
NPerXdl
;
constexpr
auto
K0PerThreadRead
=
BK0
/
KThreadRead
;
constexpr
auto
kfold
=
(
BK1
*
N0
*
sizeof
(
BDataType
)
>
128
)
?
1
:
128
/
(
BK1
*
N0
*
sizeof
(
BDataType
));
constexpr
auto
KThreadReadPerm
=
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
>
1
?
KThreadRead
/
(
kfold
*
K0PerThreadWrite
/
K0PerThreadRead
)
:
KThreadRead
;
// 1<=npair<=n0
constexpr
auto
npair
=
(
BK1
*
NPerXdl
*
sizeof
(
BDataType
)
>
128
)
?
1
:
((
128
/
(
BK1
*
NPerXdl
*
sizeof
(
BDataType
)))
>
N0
?
N0
:
128
/
(
BK1
*
NPerXdl
*
sizeof
(
BDataType
)));
constexpr
auto
b_lds_block_desc
=
make_naive_tensor_descriptor_packed
(
make_tuple
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
K0PerThreadWrite
>
{},
number
<
KThreadReadPerm
*
N1
>
{},
number
<
kfold
*
N0
/
npair
>
{},
number
<
npair
>
{},
BK1
));
constexpr
auto
b_lds_block_desc_permuted
=
transform_tensor_descriptor
(
b_lds_block_desc
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_xor_transform
(
make_tuple
(
number
<
KThreadReadPerm
*
N1
>
{},
number
<
kfold
*
N0
/
npair
>
{})),
make_pass_through_transform
(
number
<
npair
>
{}),
make_pass_through_transform
(
BK1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
,
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}));
constexpr
auto
b_lds_block_desc_unmerged
=
transform_tensor_descriptor
(
b_lds_block_desc_permuted
,
make_tuple
(
make_pass_through_transform
(
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{}),
make_pass_through_transform
(
number
<
K0PerThreadWrite
>
{}),
make_unmerge_transform
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
N1
>
{})),
make_unmerge_transform
(
make_tuple
(
number
<
kfold
>
{},
number
<
N0
/
npair
>
{})),
make_pass_through_transform
(
number
<
npair
>
{}),
make_pass_through_transform
(
BK1
)),
make_tuple
(
sequence
<
0
>
{},
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
3
>
{},
sequence
<
4
>
{},
sequence
<
5
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
2
>
{},
sequence
<
0
,
3
>
{},
sequence
<
4
,
5
>
{},
sequence
<
6
>
{},
sequence
<
7
>
{}));
// constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
// b_lds_block_desc_unmerged,
// make_tuple(make_merge_transform_v3_division_mod(
// make_tuple(number<KThreadReadPerm>{},
// number<KThreadWrite / kfold / KThreadReadPerm>{},
// number<kfold>{},
// number<K0PerThreadWrite>{})),
// make_merge_transform_v3_division_mod(
// make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{})),
// make_pass_through_transform(BK1)),
// make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}),
// make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
constexpr
auto
b_lds_block_desc_kn
=
transform_tensor_descriptor
(
b_lds_block_desc_unmerged
,
make_tuple
(
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
KThreadReadPerm
>
{},
number
<
KThreadWrite
/
kfold
/
KThreadReadPerm
>
{},
number
<
kfold
>
{},
number
<
K0PerThreadWrite
>
{},
BK1
)),
make_merge_transform_v3_division_mod
(
make_tuple
(
number
<
N0
/
npair
>
{},
number
<
npair
>
{},
number
<
N1
>
{}))),
make_tuple
(
sequence
<
0
,
1
,
4
,
2
,
7
>
{},
sequence
<
5
,
6
,
3
>
{}),
make_tuple
(
sequence
<
1
>
{},
sequence
<
0
>
{}));
// return b_lds_block_desc_bk0_n_bk1;
return
b_lds_block_desc_kn
;
// constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor(
// make_tuple(BK0, number<NPerBlock>{}, number<KPack>{}),
// make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
// number<KPack>{},
// number<1>{});
// constexpr auto b_lds_block_desc = transform_tensor_descriptor(
// b_lds_block_desc_bk0_n_bk1,
// make_tuple(make_pass_through_transform(number<NPerBlock>{}),
// make_merge_transform_v3_division_mod(make_tuple(BK0,
// number<KPack>{}))),
// make_tuple(sequence<1>{}, sequence<0, 2>{}),
// make_tuple(sequence<0>{}, sequence<1>{}));
// return b_lds_block_desc;
}
#endif
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -179,291 +452,127 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -179,291 +452,127 @@ struct UniversalGemmPipelineAgBgCrPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeA
<
Problem
>
();
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
// Tile: MPerBlock X KPerBlock
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
{
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
constexpr
index_t
M0
=
MPerBlock
/
M1
;
MPerBlock
,
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
KPerBlock
,
static_assert
(
total_pixels
%
M1
==
0
);
VecLoadSize
,
constexpr
index_t
K3
=
total_pixels
/
M1
;
ATileAccessPattern
>
;
constexpr
index_t
KPack
=
GetVectorLoadSize
<
Problem
,
ADataType
,
MPerBlock
>
();
return
TileEncodingPattern
::
Make2DStaticTileDistribution
();
static_assert
(
KPack
%
K3
==
0
);
constexpr
index_t
K2
=
KPack
/
K3
;
if
constexpr
(
get_warp_size
()
%
(
K2
*
M0
)
==
0
)
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
M0
);
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
}
// Tile: KPerBlock X MPerBlock
else
else
{
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
constexpr
index_t
K0
=
KPerBlock
/
K1
;
KPerBlock
,
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
MPerBlock
,
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
VecLoadSize
,
{
ATileAccessPattern
>
;
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
return
TileEncodingPattern
::
Make2DStaticTileDistribution
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
else
{
constexpr
index_t
M0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
MPerBlock
/
(
M2
*
M0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
N
PerBlock
=
Problem
::
BlockGemmShape
::
k
N
;
constexpr
index_t
K
PerBlock
=
Problem
::
BlockGemmShape
::
k
K
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeB
<
Problem
>
()
;
// Tile: KPerBlock X NPerBlock
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
{
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
constexpr
index_t
N0
=
NPerBlock
/
N1
;
KPerBlock
,
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
NPerBlock
,
static_assert
(
total_pixels
%
N1
==
0
);
VecLoadSize
,
constexpr
index_t
K3
=
total_pixels
/
N1
;
BTileAccessPattern
>
;
constexpr
index_t
KPack
=
GetVectorLoadSize
<
Problem
,
BDataType
,
NPerBlock
>
();
return
TileEncodingPattern
::
Make2DStaticTileDistribution
();
static_assert
(
KPack
%
K3
==
0
);
constexpr
index_t
K2
=
KPack
/
K3
;
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
}
// Tile: NPerBlock X KPerBlock
else
else
{
{
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
NPerBlock
,
constexpr
index_t
K0
=
KPerBlock
/
K1
;
KPerBlock
,
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
VecLoadSize
,
// coalesce reading for each blocks
BTileAccessPattern
>
;
if
constexpr
(
get_warp_size
()
%
(
N2
*
K0
)
==
0
)
return
TileEncodingPattern
::
Make2DStaticTileDistribution
();
{
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"N2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"N1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
// coalesce reading for each warps
else
{
constexpr
index_t
N0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
NPerBlock
/
(
N2
*
N0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledAReg
BlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledAReg
TileDistribution
()
{
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
);
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeA
<
Problem
>
();
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
MPerBlock
/
M1
;
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
KPerBlock
,
static_assert
(
total_pixels
%
M1
==
0
);
MPerBlock
,
constexpr
index_t
K3
=
total_pixels
/
M1
;
VecLoadSize
,
constexpr
index_t
kKPack
=
GetVectorLoadSize
<
Problem
,
ADataType
,
MPerBlock
>
();
ATileAccessPattern
>
;
static_assert
(
kKPack
%
K3
==
0
);
return
TileEncodingPattern
::
MakeShuffled2DStaticTileDistribution
();
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
M0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
M0
);
constexpr
index_t
K0
=
BlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBReg
BlockDescriptor
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBReg
TileDistribution
()
{
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
static_assert
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
static_assert
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
VecLoadSize
=
GetVectorSizeB
<
Problem
>
();
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
NPerBlock
/
N1
;
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
BlockSize
,
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
KPerBlock
,
static_assert
(
total_pixels
%
N1
==
0
);
NPerBlock
,
constexpr
index_t
K3
=
total_pixels
/
N1
;
VecLoadSize
,
constexpr
index_t
kKPack
=
GetVectorLoadSize
<
Problem
,
BDataType
,
NPerBlock
>
();
BTileAccessPattern
>
;
static_assert
(
kKPack
%
K3
==
0
);
return
TileEncodingPattern
::
MakeShuffled2DStaticTileDistribution
();
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
N0
);
constexpr
index_t
K0
=
BlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
TransposeC
;
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
IsTransposeC
()
{
return
Problem
::
TransposeC
;
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockGemm
()
{
{
using
AccDataType
=
float
;
using
BlockWarps
=
typename
Problem
::
BlockGemmShape
::
BlockWarps
;
using
BlockWarps
=
typename
Problem
::
BlockGemmShape
::
BlockWarps
;
using
WarpTile
=
typename
Problem
::
BlockGemmShape
::
WarpTile
;
using
WarpTile
=
typename
Problem
::
BlockGemmShape
::
WarpTile
;
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
BDataType
,
Acc
DataType
,
typename
Problem
::
C
DataType
,
WarpTile
::
at
(
I0
),
WarpTile
::
at
(
I0
),
WarpTile
::
at
(
I1
),
WarpTile
::
at
(
I1
),
WarpTile
::
at
(
I2
),
WarpTile
::
at
(
I2
),
TransposeC
>
;
Problem
::
TransposeC
>
;
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
using
BlockGemmPolicy
=
BlockGemmASmemBSmemCRegV1CustomPolicy
<
typename
Problem
::
ADataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
BDataType
,
typename
Problem
::
CDataType
,
typename
Problem
::
CDataType
,
BlockWarps
,
BlockWarps
,
WarpGemm
>
;
WarpGemm
>
;
return
Block
GemmASmemBSmemCRegV1
<
Problem
,
BlockGemmPolicy
>
{};
return
Block
UniversalGemmAsBsCr
<
Problem
,
BlockGemmPolicy
>
{};
}
}
};
};
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
View file @
dec32dc6
...
@@ -22,11 +22,36 @@ struct TileGemmTraits
...
@@ -22,11 +22,36 @@ struct TileGemmTraits
static
constexpr
bool
isDoubleSmemBuffer
=
isDoubleSmemBuffer_
;
static
constexpr
bool
isDoubleSmemBuffer
=
isDoubleSmemBuffer_
;
static
constexpr
bool
isDoubleSmemBuffer
=
isDoubleSmemBuffer_
;
// TODO this can't be hardcoded here! Should be in policy!
static
constexpr
int
_VectorSize
=
16
;
static
constexpr
int
_VectorSize
=
16
;
using
ALayout
=
ALayout_
;
using
ALayout
=
ALayout_
;
using
BLayout
=
BLayout_
;
using
BLayout
=
BLayout_
;
using
CLayout
=
CLayout_
;
using
CLayout
=
CLayout_
;
static
constexpr
bool
TransposeC
=
false
;
};
template
<
bool
kPadM_
,
bool
kPadN_
,
bool
kPadK_
,
typename
ALayout_
,
typename
BLayout_
,
typename
CLayout_
,
bool
TransposeC_
=
false
>
struct
TileGemmUniversalTraits
{
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadK
=
kPadK_
;
using
ALayout
=
ALayout_
;
using
BLayout
=
BLayout_
;
using
CLayout
=
CLayout_
;
static
constexpr
bool
TransposeC
=
TransposeC_
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -14,7 +14,7 @@ struct Layernorm2dFwdHostArgs
...
@@ -14,7 +14,7 @@ struct Layernorm2dFwdHostArgs
{
{
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_x_residual
;
// [m ,n], shortcut input, prec same as input, nullptr if not used
const
void
*
p_x_residual
;
// [m ,n], shortcut input, prec same as input, nullptr if not used
const
void
*
p_
x
_scale
;
// [1 ,n], smooth scale input, fp32, nullptr if not used
const
void
*
p_
sm
_scale
;
// [1 ,n], smooth scale input, fp32, nullptr if not used
const
void
*
p_x_bias
;
// [1, n], bias, prec same as input
const
void
*
p_x_bias
;
// [1, n], bias, prec same as input
const
void
*
p_gamma
;
// [1, n], gamma, prec same as input
const
void
*
p_gamma
;
// [1, n], gamma, prec same as input
const
void
*
p_beta
;
// [1, n], beta, prec same as input
const
void
*
p_beta
;
// [1, n], beta, prec same as input
...
@@ -43,16 +43,16 @@ struct Layernorm2dFwd
...
@@ -43,16 +43,16 @@ struct Layernorm2dFwd
using
Epilogue
=
remove_cvref_t
<
Epilogue_
>
;
using
Epilogue
=
remove_cvref_t
<
Epilogue_
>
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
XBiasDataType
=
remove_cvref_t
<
typename
Problem
::
XBiasDataType
>
;
using
XBiasDataType
=
remove_cvref_t
<
typename
Problem
::
XBiasDataType
>
;
using
GammaDataType
=
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
GammaDataType
=
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
BetaDataType
=
remove_cvref_t
<
typename
Problem
::
BetaDataType
>
;
using
BetaDataType
=
remove_cvref_t
<
typename
Problem
::
BetaDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
YDataType
=
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
MeanDataType
=
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
MeanDataType
=
remove_cvref_t
<
typename
Problem
::
MeanDataType
>
;
using
InvStdDataType
=
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
using
InvStdDataType
=
remove_cvref_t
<
typename
Problem
::
InvStdDataType
>
;
using
X
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
X
ScaleDataType
>
;
using
Smooth
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
Smooth
ScaleDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
// for simplicity, shortcut input/output type is same as X
// for simplicity, shortcut input/output type is same as X
using
XResidualDataType
=
XDataType
;
using
XResidualDataType
=
XDataType
;
...
@@ -84,7 +84,7 @@ struct Layernorm2dFwd
...
@@ -84,7 +84,7 @@ struct Layernorm2dFwd
{
{
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_x_residual
;
// [m ,n], shortcut input, prec same as input, nullptr if not used
const
void
*
p_x_residual
;
// [m ,n], shortcut input, prec same as input, nullptr if not used
const
void
*
p_
x
_scale
;
// [1 ,n], smooth scale input, fp32, nullptr if not used
const
void
*
p_
sm
_scale
;
// [1 ,n], smooth scale input, fp32, nullptr if not used
const
void
*
p_x_bias
;
// [1, n], bias, prec same as input
const
void
*
p_x_bias
;
// [1, n], bias, prec same as input
const
void
*
p_gamma
;
// [1, n], gamma, prec same as input
const
void
*
p_gamma
;
// [1, n], gamma, prec same as input
const
void
*
p_beta
;
// [1, n], beta, prec same as input
const
void
*
p_beta
;
// [1, n], beta, prec same as input
...
@@ -111,7 +111,7 @@ struct Layernorm2dFwd
...
@@ -111,7 +111,7 @@ struct Layernorm2dFwd
{
{
return
Kargs
{
hargs
.
p_x
,
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_x_residual
,
hargs
.
p_x_residual
,
hargs
.
p_
x
_scale
,
hargs
.
p_
sm
_scale
,
hargs
.
p_x_bias
,
hargs
.
p_x_bias
,
hargs
.
p_gamma
,
hargs
.
p_gamma
,
hargs
.
p_beta
,
hargs
.
p_beta
,
...
@@ -171,7 +171,7 @@ struct Layernorm2dFwd
...
@@ -171,7 +171,7 @@ struct Layernorm2dFwd
base_str
+=
_SS_
(
"_"
)
+
_SS_
(
t2s
<
YDataType
>::
name
);
base_str
+=
_SS_
(
"_"
)
+
_SS_
(
t2s
<
YDataType
>::
name
);
}
}
if
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
if
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
base_str
+=
_SS_
(
"_sx"
)
+
_SS_
(
t2s
<
X
ScaleDataType
>::
name
);
base_str
+=
_SS_
(
"_sx"
)
+
_SS_
(
t2s
<
Smooth
ScaleDataType
>::
name
);
base_str
+=
_SS_
(
"_sy"
)
+
_SS_
(
t2s
<
YScaleDataType
>::
name
);
base_str
+=
_SS_
(
"_sy"
)
+
_SS_
(
t2s
<
YScaleDataType
>::
name
);
}
}
if
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
)
{
if
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
)
{
...
@@ -356,18 +356,18 @@ struct Layernorm2dFwd
...
@@ -356,18 +356,18 @@ struct Layernorm2dFwd
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{}));
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{}));
}();
}();
auto
x
_scale_window
=
[
&
]()
{
auto
sm
_scale_window
=
[
&
]()
{
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
{
const
auto
win_
=
[
&
]()
{
const
auto
win_
=
[
&
]()
{
const
auto
tmp_0_
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
const
auto
tmp_0_
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
const
X
ScaleDataType
*>
(
kargs
.
p_
x
_scale
),
static_cast
<
const
Smooth
ScaleDataType
*>
(
kargs
.
p_
sm
_scale
),
make_tuple
(
kargs
.
n
),
make_tuple
(
kargs
.
n
),
number
<
Vector_N
>
{});
number
<
Vector_N
>
{});
return
pad_tensor_view
(
tmp_0_
,
return
pad_tensor_view
(
tmp_0_
,
make_tuple
(
number
<
Block_N
>
{}),
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
false
>
{});
//
x
_scale no need pad
sequence
<
false
>
{});
//
sm
_scale no need pad
}();
}();
return
make_tile_window
(
win_
,
make_tuple
(
number
<
Block_N
>
{}),
{
0
});
return
make_tile_window
(
win_
,
make_tuple
(
number
<
Block_N
>
{}),
{
0
});
}
}
...
@@ -405,7 +405,7 @@ struct Layernorm2dFwd
...
@@ -405,7 +405,7 @@ struct Layernorm2dFwd
y_residual_window
,
y_residual_window
,
mean_window
,
mean_window
,
inv_std_window
,
inv_std_window
,
x
_scale_window
,
sm
_scale_window
,
y_scale_window
,
y_scale_window
,
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
kargs
.
n
,
kargs
.
n
,
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -64,7 +64,7 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -64,7 +64,7 @@ struct Layernorm2dFwdPipelineOnePass
typename
YResidualWindow
,
typename
YResidualWindow
,
typename
MeanWindow
,
typename
MeanWindow
,
typename
InvStdWindow
,
typename
InvStdWindow
,
typename
X
ScaleWindow
,
typename
Smooth
ScaleWindow
,
typename
YScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
...
@@ -76,7 +76,7 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -76,7 +76,7 @@ struct Layernorm2dFwdPipelineOnePass
const
YResidualWindow
&
y_residual_window_
,
const
YResidualWindow
&
y_residual_window_
,
MeanWindow
&
mean_window
,
MeanWindow
&
mean_window
,
InvStdWindow
&
inv_std_window
,
InvStdWindow
&
inv_std_window
,
const
X
ScaleWindow
&
x
_scale_window_
,
const
Smooth
ScaleWindow
&
sm
_scale_window_
,
YScaleWindow
&
y_scale_window
,
YScaleWindow
&
y_scale_window
,
ComputeDataType
epsilon
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
ck_tile
::
index_t
row_size
,
...
@@ -190,7 +190,7 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -190,7 +190,7 @@ struct Layernorm2dFwdPipelineOnePass
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
||
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
||
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
{
Epilogue
{}(
y_window_
,
x
_scale_window_
,
y_scale_window
,
ln
,
smem
);
Epilogue
{}(
y_window_
,
sm
_scale_window_
,
y_scale_window
,
ln
,
smem
);
}
}
else
else
Epilogue
{}(
y_window_
,
ln
);
Epilogue
{}(
y_window_
,
ln
);
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -15,23 +15,23 @@ template <typename XDataType_,
...
@@ -15,23 +15,23 @@ template <typename XDataType_,
typename
YDataType_
,
typename
YDataType_
,
typename
MeanDataType_
,
typename
MeanDataType_
,
typename
InvStdDataType_
,
typename
InvStdDataType_
,
typename
X
ScaleDataType_
,
typename
Smooth
ScaleDataType_
,
typename
YScaleDataType_
,
typename
YScaleDataType_
,
typename
BlockShape_
,
typename
BlockShape_
,
typename
Traits_
>
typename
Traits_
>
struct
Layernorm2dFwdPipelineProblem
struct
Layernorm2dFwdPipelineProblem
{
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XBiasDataType
=
remove_cvref_t
<
XBiasDataType_
>
;
using
XBiasDataType
=
remove_cvref_t
<
XBiasDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
BetaDataType
=
remove_cvref_t
<
BetaDataType_
>
;
using
BetaDataType
=
remove_cvref_t
<
BetaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
MeanDataType
=
remove_cvref_t
<
MeanDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
InvStdDataType
=
remove_cvref_t
<
InvStdDataType_
>
;
using
X
ScaleDataType
=
remove_cvref_t
<
X
ScaleDataType_
>
;
using
Smooth
ScaleDataType
=
remove_cvref_t
<
Smooth
ScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -63,7 +63,7 @@ struct Layernorm2dFwdPipelineTwoPass
...
@@ -63,7 +63,7 @@ struct Layernorm2dFwdPipelineTwoPass
typename
YResidualWindow
,
typename
YResidualWindow
,
typename
MeanWindow
,
typename
MeanWindow
,
typename
InvStdWindow
,
typename
InvStdWindow
,
typename
X
ScaleWindow
,
typename
Smooth
ScaleWindow
,
typename
YScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
...
@@ -75,7 +75,7 @@ struct Layernorm2dFwdPipelineTwoPass
...
@@ -75,7 +75,7 @@ struct Layernorm2dFwdPipelineTwoPass
const
YResidualWindow
&
y_residual_window_
,
const
YResidualWindow
&
y_residual_window_
,
MeanWindow
&
mean_window
,
MeanWindow
&
mean_window
,
InvStdWindow
&
inv_std_window
,
InvStdWindow
&
inv_std_window
,
const
X
ScaleWindow
&
/*
x
_scale_window*/
,
const
Smooth
ScaleWindow
&
/*
sm
_scale_window*/
,
YScaleWindow
&
/*y_scale_window*/
,
YScaleWindow
&
/*y_scale_window*/
,
ComputeDataType
epsilon
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
ck_tile
::
index_t
row_size
,
...
...
include/ck_tile/ops/rmsnorm2d.hpp
View file @
dec32dc6
...
@@ -8,5 +8,6 @@
...
@@ -8,5 +8,6 @@
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/rmsnorm2d/kernel/rmsnorm2d_fwd_kernel.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
// host side args
// host side args
struct
Rmsnorm2dFwdHostArgs
struct
Rmsnorm2dFwdHostArgs
{
{
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_gamma
;
// [1, n], gamma, prec same as input
const
void
*
p_x_residual
;
// [m ,n], shortcut input, prec same as input, nullptr if not used
const
void
*
p_sm_scale
;
// [1 ,n], smooth scale input, fp32, nullptr if not used
const
void
*
p_gamma
;
// [1, n], gamma, prec same as input
void
*
p_y
;
// [m, n], output, fp16/bf16
void
*
p_y
;
// [m, n], output, fp16/bf16
void
*
p_invRms
;
// [m, 1], output inv-rms, prec same as input, nullptr if not used
void
*
p_y_residual
;
// [m, n], shortcut output, prec same as input, nullptr if not used
void
*
p_y_scale
;
// [m, 1], output a dynamic quant per row, nullptr if not used
void
*
p_invRms
;
// [m, 1], output inv-rms, prec same as input, nullptr if not used
float
epsilon
;
float
epsilon
;
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// x row_stride
index_t
xr_stride
;
// x residule row stride
index_t
y_stride
;
// y row stride
index_t
yr_stride
;
// y residule row stride
};
};
// TODO: Extract some type to wrapper class
// TODO: Extract some type to wrapper class
template
<
typename
Pipeline_
>
template
<
typename
Pipeline_
,
typename
Epilogue_
>
struct
Rmsnorm2dFwd
struct
Rmsnorm2dFwd
{
{
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Epilogue
=
remove_cvref_t
<
Epilogue_
>
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
GammaDataType
=
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
GammaDataType
=
remove_cvref_t
<
typename
Problem
::
GammaDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YDataType
=
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
YDataType
=
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
InvRmsDataType
=
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
InvRmsDataType
=
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
SmoothScaleDataType
=
remove_cvref_t
<
typename
Problem
::
SmoothScaleDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
// for simplicity, shortcut input/output type is same as X
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
null_type
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
null_type
>
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
kSaveInvRms
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
Traits
::
kSaveInvRms
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
bool
kPadM
=
false
;
// always no need to pad along M
static
constexpr
bool
kPadM
=
false
;
// always no need to pad along M
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kTwoPass
=
Problem
::
kTwoPass
;
static
constexpr
bool
kTwoPass
=
Problem
::
Traits
::
kTwoPass
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
index_t
ThreadPerWarp_N
=
Problem
::
BlockShape
::
ThreadPerWarp_N
;
static
constexpr
index_t
ThreadPerWarp_N
=
Problem
::
BlockShape
::
ThreadPerWarp_N
;
static
constexpr
index_t
Vector_N
=
Problem
::
BlockShape
::
Vector_N
;
static
constexpr
index_t
Vector_N
=
Problem
::
BlockShape
::
Vector_N
;
...
@@ -56,29 +73,43 @@ struct Rmsnorm2dFwd
...
@@ -56,29 +73,43 @@ struct Rmsnorm2dFwd
struct
Kargs
struct
Kargs
{
{
const
void
*
p_x
;
const
void
*
p_x
;
const
void
*
p_x_residual
;
const
void
*
p_sm_scale
;
const
void
*
p_gamma
;
const
void
*
p_gamma
;
void
*
p_y
;
void
*
p_y
;
void
*
p_y_residual
;
void
*
p_y_scale
;
void
*
p_invRms
;
void
*
p_invRms
;
float
epsilon
;
float
epsilon
;
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// x row_stride
index_t
xr_stride
;
// x residule row stride
index_t
y_stride
;
// y row stride
index_t
yr_stride
;
// y residule row stride
};
};
using
Hargs
=
Rmsnorm2dFwdHostArgs
;
using
Hargs
=
Rmsnorm2dFwdHostArgs
;
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
{
return
Kargs
{
hargs
.
p_x
,
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_x_residual
,
hargs
.
p_sm_scale
,
hargs
.
p_gamma
,
hargs
.
p_gamma
,
hargs
.
p_y
,
hargs
.
p_y
,
hargs
.
p_y_residual
,
hargs
.
p_y_scale
,
hargs
.
p_invRms
,
hargs
.
p_invRms
,
hargs
.
epsilon
,
hargs
.
epsilon
,
hargs
.
m
,
hargs
.
m
,
hargs
.
n
,
hargs
.
n
,
hargs
.
stride
};
hargs
.
x_stride
,
hargs
.
xr_stride
,
hargs
.
y_stride
,
hargs
.
yr_stride
};
}
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
...
@@ -95,6 +126,7 @@ struct Rmsnorm2dFwd
...
@@ -95,6 +126,7 @@ struct Rmsnorm2dFwd
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
int8_t
>
{
static
constexpr
const
char
*
name
=
"int8"
;
};
// clang-format on
// clang-format on
// in byte
// in byte
...
@@ -102,24 +134,41 @@ struct Rmsnorm2dFwd
...
@@ -102,24 +134,41 @@ struct Rmsnorm2dFwd
CK_TILE_HOST
static
std
::
string
GetName
()
CK_TILE_HOST
static
std
::
string
GetName
()
{
{
#define _SS_ std::string
#define _TS_ std::to_string
// clang-format off
// clang-format off
using
S_
=
typename
Problem
::
BlockShape
;
using
S_
=
typename
Problem
::
BlockShape
;
auto
surfix
=
[
&
]
()
{
auto
surfix
=
[
&
]
()
{
std
::
string
n
;
std
::
string
n
;
if
(
kFusedAdd
!=
Rmsnorm2dFusedAddEnum
::
NO_ADD
)
n
+=
_SS_
(
"_"
)
+
Rmsnorm2dFusedAddEnumName
<
kFusedAdd
>::
name
;
if
(
kFusedQuant
!=
Rmsnorm2dFusedQuantEnum
::
NO_SWEEP
)
n
+=
_SS_
(
"_"
)
+
Rmsnorm2dFusedQuantEnumName
<
kFusedQuant
>::
name
;
if
(
kPadN
)
n
+=
"_pn"
;
if
(
kPadN
)
n
+=
"_pn"
;
if
(
kSaveInvRms
)
n
+=
"_rms"
;
if
(
kSaveInvRms
)
n
+=
"_rms"
;
if
(
kTwoPass
)
n
+=
"_2p"
;
if
(
kTwoPass
)
n
+=
"_2p"
;
return
n
;
}();
return
n
;
}();
#define _SS_ std::string
auto
prec_str
=
[
&
]
()
{
#define _TS_ std::to_string
std
::
string
base_str
=
_SS_
(
t2s
<
XDataType
>::
name
);
return
_SS_
(
"rmsnorm2d_fwd_"
)
+
_SS_
(
t2s
<
XDataType
>::
name
)
+
"_"
+
if
(
!
std
::
is_same_v
<
XDataType
,
YDataType
>
)
{
base_str
+=
_SS_
(
"_"
)
+
_SS_
(
t2s
<
YDataType
>::
name
);
}
if
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
base_str
+=
_SS_
(
"_sx"
)
+
_SS_
(
t2s
<
SmoothScaleDataType
>::
name
);
base_str
+=
_SS_
(
"_sy"
)
+
_SS_
(
t2s
<
YScaleDataType
>::
name
);
}
if
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
DYNAMIC_QUANT
)
{
base_str
+=
_SS_
(
"_sy"
)
+
_SS_
(
t2s
<
YScaleDataType
>::
name
);
}
return
base_str
;
}();
return
_SS_
(
"rmsnorm2d_fwd_"
)
+
_SS_
(
prec_str
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_SS_
(
Pipeline
::
name
)
+
surfix
;
_SS_
(
Pipeline
::
name
)
+
surfix
;
#undef _SS_
#undef _TS_
// clang-format on
// clang-format on
#undef _SS_
#undef _TS_
}
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
...
@@ -130,7 +179,7 @@ struct Rmsnorm2dFwd
...
@@ -130,7 +179,7 @@ struct Rmsnorm2dFwd
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
x_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -140,6 +189,29 @@ struct Rmsnorm2dFwd
...
@@ -140,6 +189,29 @@ struct Rmsnorm2dFwd
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
}();
const
auto
x_residual_window
=
[
&
]()
{
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD
||
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XResidualDataType
*>
(
kargs
.
p_x_residual
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
xr_stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}
else
{
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}));
}
}();
const
auto
gamma_window
=
[
&
]()
{
const
auto
gamma_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
GammaDataType
*>
(
kargs
.
p_gamma
),
static_cast
<
const
GammaDataType
*>
(
kargs
.
p_gamma
),
...
@@ -158,7 +230,7 @@ struct Rmsnorm2dFwd
...
@@ -158,7 +230,7 @@ struct Rmsnorm2dFwd
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
y_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -168,6 +240,28 @@ struct Rmsnorm2dFwd
...
@@ -168,6 +240,28 @@ struct Rmsnorm2dFwd
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
}();
auto
y_residual_window
=
[
&
]()
{
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YResidualDataType
*>
(
kargs
.
p_y_residual
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
yr_stride
,
1
),
number
<
Vector_N
>
{},
number
<
1
>
{});
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
sequence
<
kPadM
,
kPadN
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}
else
{
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}));
}
}();
auto
inv_rms_window
=
[
&
]()
{
auto
inv_rms_window
=
[
&
]()
{
if
constexpr
(
kSaveInvRms
)
if
constexpr
(
kSaveInvRms
)
{
{
...
@@ -187,15 +281,62 @@ struct Rmsnorm2dFwd
...
@@ -187,15 +281,62 @@ struct Rmsnorm2dFwd
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{}));
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{}));
}();
}();
auto
sm_scale_window
=
[
&
]()
{
if
constexpr
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
const
auto
win_
=
[
&
]()
{
const
auto
tmp_0_
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
const
SmoothScaleDataType
*>
(
kargs
.
p_sm_scale
),
make_tuple
(
kargs
.
n
),
number
<
Vector_N
>
{});
return
pad_tensor_view
(
tmp_0_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
false
>
{});
// sm_scale no need pad
}();
return
make_tile_window
(
win_
,
make_tuple
(
number
<
Block_N
>
{}),
{
0
});
}
else
{
return
make_null_tile_window
(
make_tuple
(
number
<
Block_N
>
{}));
}
}();
auto
y_scale_window
=
[
&
]()
{
if
constexpr
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
||
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
DYNAMIC_QUANT
)
{
const
auto
win_
=
[
&
]()
{
const
auto
tmp_0_
=
make_naive_tensor_view_packed
<
address_space_enum
::
global
>
(
static_cast
<
YScaleDataType
*>
(
kargs
.
p_y_scale
),
make_tuple
(
kargs
.
m
),
number
<
1
>
{});
return
pad_tensor_view
(
tmp_0_
,
make_tuple
(
number
<
Block_M
>
{}),
sequence
<
kPadM
>
{});
}();
return
make_tile_window
(
win_
,
make_tuple
(
number
<
Block_M
>
{}),
{
iM
});
}
else
{
return
make_null_tile_window
(
make_tuple
(
number
<
Block_M
>
{}));
}
}();
__shared__
char
smem
[
GetSmemSize
()];
__shared__
char
smem
[
GetSmemSize
()];
Pipeline
{}(
x_window
,
Pipeline
{}(
x_window
,
x_residual_window
,
gamma_window
,
gamma_window
,
y_window
,
y_window
,
y_residual_window
,
inv_rms_window
,
inv_rms_window
,
sm_scale_window
,
y_scale_window
,
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
static_cast
<
const
ComputeDataType
>
(
kargs
.
epsilon
),
kargs
.
n
,
kargs
.
n
,
smem
);
smem
,
Epilogue
{});
}
}
};
};
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp
View file @
dec32dc6
...
@@ -45,7 +45,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
...
@@ -45,7 +45,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2d
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2d
()
{
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
typename
Problem
::
BlockShape
>
;
return
BlockReduce2d
<
P_
>
{};
return
BlockReduce2d
<
P_
>
{};
...
@@ -54,7 +54,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
...
@@ -54,7 +54,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dSync
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dSync
()
{
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
typename
Problem
::
BlockShape
>
;
return
BlockReduce2dSync
<
P_
>
{};
return
BlockReduce2dSync
<
P_
>
{};
...
@@ -63,7 +63,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
...
@@ -63,7 +63,7 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dCrossWarpSync
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dCrossWarpSync
()
{
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
typename
Problem
::
BlockShape
>
;
return
BlockReduce2dCrossWarpSync
<
P_
>
{};
return
BlockReduce2dCrossWarpSync
<
P_
>
{};
...
@@ -74,13 +74,13 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
...
@@ -74,13 +74,13 @@ struct Rmsnorm2dFwdPipelineDefaultPolicy
{
{
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
{
{
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockReduce2dProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
typename
Problem
::
BlockShape
>
;
using
block_reduce2d
=
BlockReduce2d
<
P_
>
;
using
block_reduce2d
=
BlockReduce2d
<
P_
>
;
using
x_block_tile
=
using
x_block_tile
=
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
X
DataType
>
(
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
Compute
DataType
>
(
MakeXBlockTileDistribution
<
Problem
>
()));
MakeXBlockTileDistribution
<
Problem
>
()));
using
y_block_tile
=
decltype
(
block_reduce2d
::
template
MakeYBlockTile
<
x_block_tile
>());
using
y_block_tile
=
decltype
(
block_reduce2d
::
template
MakeYBlockTile
<
x_block_tile
>());
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_one_pass.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineOnePass
...
@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineOnePass
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
InvRmsDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
InvRmsDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
kSaveInvRms
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
Traits
::
kSaveInvRms
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockRmsnorm2dFwdProblem::kPadM
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockRmsnorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
const
char
*
name
=
[]()
{
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
if
constexpr
(
kNeedCrossWarpSync
)
...
@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineOnePass
...
@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineOnePass
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
template
<
typename
XWindow
,
typename
GammaWindow
,
typename
YWindow
,
typename
InvRmsWindow
>
template
<
typename
XWindow
,
typename
XResidualWindow
,
typename
GammaWindow
,
typename
YWindow
,
typename
YResidualWindow
,
typename
InvRmsWindow
,
typename
SmoothScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
XResidualWindow
&
x_residual_window_
,
const
GammaWindow
&
gamma_window_
,
const
GammaWindow
&
gamma_window_
,
YWindow
&
y_window
,
YWindow
&
y_window_
,
const
YResidualWindow
&
y_residual_window_
,
InvRmsWindow
&
inv_rms_window
,
InvRmsWindow
&
inv_rms_window
,
const
SmoothScaleWindow
&
sm_scale_window_
,
YScaleWindow
&
y_scale_window_
,
ComputeDataType
epsilon
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
void
*
smem
,
Epilogue
)
const
{
{
const
auto
x_window
=
const
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
const
auto
gamma_window
=
make_tile_window
(
const
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBlockTileDistribution
<
Problem
>());
gamma_window_
,
Policy
::
template
MakeGammaBlockTileDistribution
<
Problem
>());
const
auto
x_residual_window
=
make_tile_window
(
x_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
y_residual_window
=
make_tile_window
(
y_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
reduce_square_sum_func
=
ReduceOp
::
SquareAdd
{};
auto
reduce_square_sum_func
=
ReduceOp
::
SquareAdd
{};
auto
reduce_sum_func
=
ReduceOp
::
Add
{};
auto
reduce_sum_func
=
ReduceOp
::
Add
{};
...
@@ -62,13 +84,31 @@ struct Rmsnorm2dFwdPipelineOnePass
...
@@ -62,13 +84,31 @@ struct Rmsnorm2dFwdPipelineOnePass
auto
block_reduce2d_cross_warp_sync
=
auto
block_reduce2d_cross_warp_sync
=
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
const
auto
x
=
load_tile
(
x_window
);
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
// load gamma (TODO: support no gamma?)
// load gamma (TODO: support no gamma?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
gamma
=
load_tile
(
gamma_window
);
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD
||
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
});
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
));
}
}
// compute mean square each-thread->cross-lane->cross-warp
// compute mean square each-thread->cross-lane->cross-warp
auto
square_sum
=
block_reduce2d
(
auto
square_sum
=
block_reduce2d
(
acc
,
x
,
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
(),
reduce_square_sum_func
);
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
(),
reduce_square_sum_func
);
block_reduce2d_sync
(
square_sum
,
reduce_sum_func
);
block_reduce2d_sync
(
square_sum
,
reduce_sum_func
);
block_reduce2d_cross_warp_sync
(
square_sum
,
smem
,
reduce_sum_func
);
block_reduce2d_cross_warp_sync
(
square_sum
,
smem
,
reduce_sum_func
);
...
@@ -83,19 +123,30 @@ struct Rmsnorm2dFwdPipelineOnePass
...
@@ -83,19 +123,30 @@ struct Rmsnorm2dFwdPipelineOnePass
store_tile
(
inv_rms_window
,
cast_tile
<
InvRmsDataType
>
(
inv_rms
));
store_tile
(
inv_rms_window
,
cast_tile
<
InvRmsDataType
>
(
inv_rms
));
// rmsnorm computation
// rmsnorm computation
auto
y
=
make_static_distributed_tensor
<
Y
DataType
>
(
x
.
get_tile_distribution
());
auto
rmsn
=
make_static_distributed_tensor
<
Compute
DataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
sweep_tile
(
rmsn
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
rmsn_
=
acc
[
idx
]
*
inv_rms_
[
i_idx
]
*
gamma_
;
auto
y_
=
x_
*
inv_rms_
[
i_idx
]
*
gamma_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
)
;
rmsn
(
idx
)
=
rmsn_
;
});
});
store_tile
(
y_window
,
y
);
if
constexpr
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
Epilogue
{}(
y_window_
,
sm_scale_window_
,
y_scale_window_
,
rmsn
,
smem
);
}
else
if
constexpr
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
DYNAMIC_QUANT
)
{
Epilogue
{}(
y_window_
,
y_scale_window_
,
rmsn
,
smem
);
}
else
{
Epilogue
{}(
y_window_
,
rmsn
);
}
}
}
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_problem.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -12,25 +12,25 @@ template <typename XDataType_,
...
@@ -12,25 +12,25 @@ template <typename XDataType_,
typename
ComputeDataType_
,
typename
ComputeDataType_
,
typename
YDataType_
,
typename
YDataType_
,
typename
InvRmsDataType_
,
typename
InvRmsDataType_
,
typename
SmoothScaleDataType_
,
typename
YScaleDataType_
,
typename
BlockShape_
,
typename
BlockShape_
,
bool
kPadN_
,
typename
Traits_
>
bool
kSaveInvRms_
,
bool
kTwoPass_
>
struct
Rmsnorm2dFwdPipelineProblem
struct
Rmsnorm2dFwdPipelineProblem
{
{
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
XDataType
=
remove_cvref_t
<
XDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
GammaDataType
=
remove_cvref_t
<
GammaDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
ComputeDataType
=
remove_cvref_t
<
ComputeDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
YDataType
=
remove_cvref_t
<
YDataType_
>
;
using
InvRmsDataType
=
remove_cvref_t
<
InvRmsDataType_
>
;
using
InvRmsDataType
=
remove_cvref_t
<
InvRmsDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
using
SmoothScaleDataType
=
remove_cvref_t
<
SmoothScaleDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossLaneSync
=
BlockShape
::
ThreadPerWarp_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kNeedCrossWarpSync
=
BlockShape
::
WarpPerBlock_N
>
1
;
static
constexpr
bool
kPadN
=
kPadN_
;
using
Traits
=
remove_cvref_t
<
Traits_
>
;
static
constexpr
bool
kSaveInvRms
=
kSaveInvRms_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_two_pass.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineTwoPass
...
@@ -22,12 +22,17 @@ struct Rmsnorm2dFwdPipelineTwoPass
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
YDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YDataType
>
;
using
InvRmsDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
InvRmsDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
InvRmsDataType
>
;
using
XResidualDataType
=
XDataType
;
using
YResidualDataType
=
XDataType
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kHasGamma
=
!
std
::
is_same_v
<
GammaDataType
,
ck_tile
::
null_type
>
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
kSaveInvRms
;
static
constexpr
bool
kSaveInvRms
=
Problem
::
Traits
::
kSaveInvRms
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockRmsnorm2dFwdProblem::kPadM
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockRmsnorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
const
char
*
name
=
[]()
{
static
constexpr
const
char
*
name
=
[]()
{
if
constexpr
(
kNeedCrossWarpSync
)
if
constexpr
(
kNeedCrossWarpSync
)
...
@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineTwoPass
...
@@ -41,19 +46,36 @@ struct Rmsnorm2dFwdPipelineTwoPass
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
template
<
typename
XWindow
,
typename
GammaWindow
,
typename
YWindow
,
typename
InvRmsWindow
>
template
<
typename
XWindow
,
typename
XResidualWindow
,
typename
GammaWindow
,
typename
YWindow
,
typename
YResidualWindow
,
typename
InvRmsWindow
,
typename
SmoothScaleWindow
,
typename
YScaleWindow
,
typename
Epilogue
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
XResidualWindow
&
x_residual_window_
,
const
GammaWindow
&
gamma_window_
,
const
GammaWindow
&
gamma_window_
,
YWindow
&
y_window
,
YWindow
&
y_window
,
const
YResidualWindow
&
y_residual_window_
,
InvRmsWindow
&
inv_rms_window
,
InvRmsWindow
&
inv_rms_window
,
const
SmoothScaleWindow
&
/*sm_scale_window_*/
,
YScaleWindow
&
/*y_scale_window*/
,
ComputeDataType
epsilon
,
ComputeDataType
epsilon
,
ck_tile
::
index_t
row_size
,
ck_tile
::
index_t
row_size
,
void
*
smem
)
const
void
*
smem
,
Epilogue
)
const
{
{
auto
x_window
=
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
gamma_window
=
make_tile_window
(
auto
gamma_window
=
make_tile_window
(
gamma_window_
,
Policy
::
template
MakeGammaBlockTileDistribution
<
Problem
>());
gamma_window_
,
Policy
::
template
MakeGammaBlockTileDistribution
<
Problem
>());
auto
x_residual_window
=
make_tile_window
(
x_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
y_residual_window
=
make_tile_window
(
y_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
// Problem::BlockShape
// Problem::BlockShape
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
...
@@ -67,15 +89,34 @@ struct Rmsnorm2dFwdPipelineTwoPass
...
@@ -67,15 +89,34 @@ struct Rmsnorm2dFwdPipelineTwoPass
auto
block_reduce2d_cross_warp_sync
=
auto
block_reduce2d_cross_warp_sync
=
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
using
X
TensorType
=
decltype
(
load_tile
(
x_window
));
using
Compute
TensorType
=
decltype
(
cast_tile
<
ComputeDataType
>
(
load_tile
(
x_window
))
)
;
auto
square_sum
=
block_reduce2d
.
template
MakeYBlockTile
<
X
TensorType
>();
auto
square_sum
=
block_reduce2d
.
template
MakeYBlockTile
<
Compute
TensorType
>();
set_tile
(
square_sum
,
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
());
set_tile
(
square_sum
,
reduce_square_sum_func
.
GetIdentityValue
<
ComputeDataType
>
());
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
{
const
auto
x
=
load_tile
(
x_window
);
auto
x
=
load_tile
(
x_window
);
block_reduce2d
(
x
,
square_sum
,
reduce_square_sum_func
);
auto
x_resi
=
load_tile
(
x_residual_window
);
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
Block_N
});
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD
||
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
});
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
));
move_tile_window
(
y_residual_window
,
{
0
,
Block_N
});
}
}
block_reduce2d
(
acc
,
square_sum
,
reduce_square_sum_func
);
}
}
block_reduce2d_sync
(
square_sum
,
reduce_sum_func
);
block_reduce2d_sync
(
square_sum
,
reduce_sum_func
);
...
@@ -96,33 +137,47 @@ struct Rmsnorm2dFwdPipelineTwoPass
...
@@ -96,33 +137,47 @@ struct Rmsnorm2dFwdPipelineTwoPass
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
row_size
%
Block_N
==
0
?
row_size
-
Block_N
:
row_size
-
row_size
%
Block_N
;
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
stride_to_right_most_window
});
move_tile_window
(
gamma_window
,
{
stride_to_right_most_window
});
move_tile_window
(
y_window
,
{
0
,
stride_to_right_most_window
});
move_tile_window
(
y_window
,
{
0
,
stride_to_right_most_window
});
// rmsnorm computation
// rmsnorm computation
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
for
(
int
iN
=
__builtin_amdgcn_readfirstlane
(
0
);
iN
<
num_n_tile_iteration
;
++
iN
)
{
{
const
auto
x
=
load_tile
(
x_window
);
auto
x
=
load_tile
(
x_window
);
// load gamma/beta (TODO: support no gamma/beta?)
auto
x_resi
=
load_tile
(
x_residual_window
);
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Rmsnorm2dFusedAddEnum
::
PRE_ADD
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
});
}
// load gamma (TODO: support no gamma?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
gamma
=
load_tile
(
gamma_window
);
auto
y
=
make_static_distributed_tensor
<
YDataType
>
(
x
.
get_tile_distribution
());
// rmsnorm computation
auto
rmsn
=
make_static_distributed_tensor
<
ComputeDataType
>
(
x
.
get_tile_distribution
());
sweep_tile
(
y
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
sweep_tile
(
rmsn
,
[
&
,
inv_rms_
=
inv_rms
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
rmsn_
=
acc
(
idx
)
*
inv_rms_
[
i_idx
]
*
gamma_
;
auto
y_
=
x_
*
inv_rms_
[
i_idx
]
*
gamma_
;
y
(
idx
)
=
type_convert
<
YDataType
>
(
y_
)
;
rmsn
(
idx
)
=
rmsn_
;
});
});
store_tile
(
y_window
,
y
);
static_assert
(
kFusedQuant
==
Rmsnorm2dFusedQuantEnum
::
NO_SWEEP
);
Epilogue
{}(
y_window
,
rmsn
);
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_window
,
{
0
,
-
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
-
Block_N
});
move_tile_window
(
gamma_window
,
{
-
Block_N
});
move_tile_window
(
gamma_window
,
{
-
Block_N
});
move_tile_window
(
y_window
,
{
0
,
-
Block_N
});
move_tile_window
(
y_window
,
{
0
,
-
Block_N
});
}
}
...
...
include/ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_traits.hpp
0 → 100644
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core/utility/type_traits.hpp"
namespace
ck_tile
{
/// Selects whether a residual add is fused in front of the RMSNorm computation.
enum class Rmsnorm2dFusedAddEnum
{
    /// No residual add is fused.
    NO_ADD = 0,
    /// Add the residual before RMSNorm and also store the summed result to global memory.
    PRE_ADD_STORE = 1,
    /// Add the residual before RMSNorm, but do not store the summed result.
    PRE_ADD = 2,
};
// clang-format off
/// Compile-time lookup from a Rmsnorm2dFusedAddEnum value to a short string tag.
/// Only the explicit specializations below are defined; the primary template is
/// intentionally left incomplete so an unhandled enumerator fails to compile.
/// NOTE(review): the tags are presumably used when composing kernel/instance
/// names - confirm against callers.
template <Rmsnorm2dFusedAddEnum>
struct Rmsnorm2dFusedAddEnumName;

/// Tag for NO_ADD.
template <>
struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::NO_ADD>
{
    static constexpr const char* name = "no";
};

/// Tag for PRE_ADD_STORE.
template <>
struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD_STORE>
{
    static constexpr const char* name = "pras";
};

/// Tag for PRE_ADD.
template <>
struct Rmsnorm2dFusedAddEnumName<Rmsnorm2dFusedAddEnum::PRE_ADD>
{
    static constexpr const char* name = "pra";
};
// clang-format on
/// Selects which quantization step, if any, is fused after the RMSNorm computation.
enum class Rmsnorm2dFusedQuantEnum
{
    /// No fused quantization sweep.
    NO_SWEEP = 0,
    /// Smooth outlier + row-wise quantization; needs an input x-scale and stores a y-scale.
    SMOOTH_DYNAMIC_QUANT = 1,
    /// Row-wise quantization; stores out a y-scale.
    DYNAMIC_QUANT = 2,
};
// clang-format off
/// Compile-time lookup from a Rmsnorm2dFusedQuantEnum value to a short string tag.
/// Only the explicit specializations below are defined; the primary template is
/// intentionally left incomplete so an unhandled enumerator fails to compile.
/// NOTE(review): the tags are presumably used when composing kernel/instance
/// names - confirm against callers.
template <Rmsnorm2dFusedQuantEnum>
struct Rmsnorm2dFusedQuantEnumName;

/// Tag for NO_SWEEP.
template <>
struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::NO_SWEEP>
{
    static constexpr const char* name = "no";
};

/// Tag for DYNAMIC_QUANT.
template <>
struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::DYNAMIC_QUANT>
{
    static constexpr const char* name = "dqt";
};

/// Tag for SMOOTH_DYNAMIC_QUANT.
template <>
struct Rmsnorm2dFusedQuantEnumName<Rmsnorm2dFusedQuantEnum::SMOOTH_DYNAMIC_QUANT>
{
    static constexpr const char* name = "smdqt";
};
// clang-format on
/// Compile-time configuration bundle for the RMSNorm 2D forward pipelines.
/// Each template argument is re-exported unchanged as a static constexpr member
/// so that other templates can query it through the traits type.
///
/// @tparam kPadN_       exported as kPadN
/// @tparam kSaveInvRms_ exported as kSaveInvRms
/// @tparam kTwoPass_    exported as kTwoPass
/// @tparam kFusedAdd_   fused residual-add mode, exported as kFusedAdd
/// @tparam kFusedQuant_ fused quantization mode, exported as kFusedQuant
template <bool kPadN_,
          bool kSaveInvRms_,
          bool kTwoPass_,
          Rmsnorm2dFusedAddEnum kFusedAdd_,
          Rmsnorm2dFusedQuantEnum kFusedQuant_>
struct Rmsnorm2dFwdTraits
{
    // Boolean switches.
    static constexpr bool kPadN       = kPadN_;
    static constexpr bool kSaveInvRms = kSaveInvRms_;
    static constexpr bool kTwoPass    = kTwoPass_;

    // Fusion modes.
    static constexpr Rmsnorm2dFusedAddEnum kFusedAdd     = kFusedAdd_;
    static constexpr Rmsnorm2dFusedQuantEnum kFusedQuant = kFusedQuant_;
};
}
// namespace ck_tile
include/ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -12,7 +12,7 @@ namespace ck_tile {
...
@@ -12,7 +12,7 @@ namespace ck_tile {
struct
MoeSmoothquantHostArgs
struct
MoeSmoothquantHostArgs
{
{
const
void
*
p_x
;
// [tokens ,hidden_size], input, fp16/bf16
const
void
*
p_x
;
// [tokens ,hidden_size], input, fp16/bf16
const
void
*
p_
x
scale
;
// [experts, hidden_size], input, columnwise scale, fp32
const
void
*
p_
sm
scale
;
// [experts, hidden_size], input, columnwise scale, fp32
const
void
*
p_topk_ids
;
// [tokens, topk]
const
void
*
p_topk_ids
;
// [tokens, topk]
void
*
p_yscale
;
// [topk * tokens, 1], output, rowwise quant scale
void
*
p_yscale
;
// [topk * tokens, 1], output, rowwise quant scale
...
@@ -33,11 +33,11 @@ struct MoeSmoothquant
...
@@ -33,11 +33,11 @@ struct MoeSmoothquant
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
X
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
X
ScaleDataType
>
;
using
Smooth
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
Smooth
ScaleDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
QYDataType
=
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
using
QYDataType
=
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
...
@@ -57,7 +57,7 @@ struct MoeSmoothquant
...
@@ -57,7 +57,7 @@ struct MoeSmoothquant
struct
Kargs
struct
Kargs
{
{
const
void
*
p_x
;
// [tokens ,hidden_size], input, fp16/bf16
const
void
*
p_x
;
// [tokens ,hidden_size], input, fp16/bf16
const
void
*
p_
x
scale
;
// [experts, hidden_size], input, columnwise scale, fp32
const
void
*
p_
sm
scale
;
// [experts, hidden_size], input, columnwise scale, fp32
const
void
*
p_topk_ids
;
// [tokens, topk]
const
void
*
p_topk_ids
;
// [tokens, topk]
void
*
p_yscale
;
// [topk, tokens, 1], output, rowwise quant scale
void
*
p_yscale
;
// [topk, tokens, 1], output, rowwise quant scale
...
@@ -75,7 +75,7 @@ struct MoeSmoothquant
...
@@ -75,7 +75,7 @@ struct MoeSmoothquant
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
{
return
Kargs
{
hargs
.
p_x
,
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_
x
scale
,
hargs
.
p_
sm
scale
,
hargs
.
p_topk_ids
,
hargs
.
p_topk_ids
,
hargs
.
p_yscale
,
hargs
.
p_yscale
,
hargs
.
p_qy
,
hargs
.
p_qy
,
...
@@ -101,6 +101,7 @@ struct MoeSmoothquant
...
@@ -101,6 +101,7 @@ struct MoeSmoothquant
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf16_t
>
{
static
constexpr
const
char
*
name
=
"bf16"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
fp8_t
>
{
static
constexpr
const
char
*
name
=
"fp8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
bf8_t
>
{
static
constexpr
const
char
*
name
=
"bf8"
;
};
template
<
>
struct
t2s
<
ck_tile
::
int8_t
>
{
static
constexpr
const
char
*
name
=
"i8"
;
};
// clang-format on
// clang-format on
// in byte
// in byte
...
@@ -118,7 +119,7 @@ struct MoeSmoothquant
...
@@ -118,7 +119,7 @@ struct MoeSmoothquant
#define _SS_ std::string
#define _SS_ std::string
#define _TS_ std::to_string
#define _TS_ std::to_string
return
_SS_
(
"moe_smoothquant_"
)
+
_SS_
(
t2s
<
XDataType
>::
name
)
+
"_"
+
return
_SS_
(
"moe_smoothquant_"
)
+
_SS_
(
t2s
<
XDataType
>::
name
)
+
"_"
+
_SS_
(
t2s
<
QYDataType
>::
name
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_SS_
(
Pipeline
::
name
)
+
surfix
;
_SS_
(
Pipeline
::
name
)
+
surfix
;
...
@@ -153,9 +154,10 @@ struct MoeSmoothquant
...
@@ -153,9 +154,10 @@ struct MoeSmoothquant
}();
}();
// [experts, hidden_size],
// [experts, hidden_size],
const
auto
x
scale_window
=
[
&
]()
{
const
auto
sm
scale_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XScaleDataType
*>
(
kargs
.
p_xscale
)
+
i_expert
*
kargs
.
hidden_size
,
static_cast
<
const
SmoothScaleDataType
*>
(
kargs
.
p_smscale
)
+
i_expert
*
kargs
.
hidden_size
,
make_tuple
(
kargs
.
hidden_size
),
make_tuple
(
kargs
.
hidden_size
),
make_tuple
(
1
),
make_tuple
(
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
...
@@ -198,7 +200,7 @@ struct MoeSmoothquant
...
@@ -198,7 +200,7 @@ struct MoeSmoothquant
__shared__
char
smem
[
GetSmemSize
()];
__shared__
char
smem
[
GetSmemSize
()];
Pipeline
{}(
x_window
,
x
scale_window
,
yscale_window
,
qy_window
,
kargs
.
hidden_size
,
smem
);
Pipeline
{}(
x_window
,
sm
scale_window
,
yscale_window
,
qy_window
,
kargs
.
hidden_size
,
smem
);
}
}
};
};
...
...
include/ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -11,11 +11,11 @@ namespace ck_tile {
...
@@ -11,11 +11,11 @@ namespace ck_tile {
// host side args
// host side args
struct
SmoothquantHostArgs
struct
SmoothquantHostArgs
{
{
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_x
;
// [m ,n], input, fp16/bf16
const
void
*
p_
x
scale
;
// [1, n], input, columnwise scale, fp32
const
void
*
p_
sm
scale
;
// [1, n], input, columnwise scale, fp32
void
*
p_yscale
;
// [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_
x
scale)
void
*
p_yscale
;
// [m, 1], output, rowwise quant scale (amax / 127) of (p_x * p_
sm
scale)
void
*
p_qy
;
// [m, n], output, p_x * p_
x
scale / p_yscale
void
*
p_qy
;
// [m, n], output, p_x * p_
sm
scale / p_yscale
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
...
@@ -30,11 +30,11 @@ struct Smoothquant
...
@@ -30,11 +30,11 @@ struct Smoothquant
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
Problem
=
typename
Pipeline
::
Problem
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
XDataType
=
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
X
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
X
ScaleDataType
>
;
using
Smooth
ScaleDataType
=
remove_cvref_t
<
typename
Problem
::
Smooth
ScaleDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
ComputeDataType
=
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
QYDataType
=
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
using
QYDataType
=
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_M
=
Problem
::
BlockShape
::
Block_M
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
static
constexpr
index_t
Block_N
=
Problem
::
BlockShape
::
Block_N
;
...
@@ -52,7 +52,7 @@ struct Smoothquant
...
@@ -52,7 +52,7 @@ struct Smoothquant
struct
Kargs
struct
Kargs
{
{
const
void
*
p_x
;
const
void
*
p_x
;
const
void
*
p_
x
scale
;
const
void
*
p_
sm
scale
;
void
*
p_yscale
;
void
*
p_yscale
;
void
*
p_qy
;
void
*
p_qy
;
...
@@ -67,7 +67,7 @@ struct Smoothquant
...
@@ -67,7 +67,7 @@ struct Smoothquant
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
Kargs
MakeKargs
(
const
Hargs
&
hargs
)
{
{
return
Kargs
{
hargs
.
p_x
,
return
Kargs
{
hargs
.
p_x
,
hargs
.
p_
x
scale
,
hargs
.
p_
sm
scale
,
hargs
.
p_yscale
,
hargs
.
p_yscale
,
hargs
.
p_qy
,
hargs
.
p_qy
,
hargs
.
m
,
hargs
.
m
,
...
@@ -134,9 +134,9 @@ struct Smoothquant
...
@@ -134,9 +134,9 @@ struct Smoothquant
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
tmp2_
,
make_tuple
(
number
<
Block_M
>
{},
number
<
Block_N
>
{}),
{
iM
,
0
});
}();
}();
const
auto
x
scale_window
=
[
&
]()
{
const
auto
sm
scale_window
=
[
&
]()
{
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
X
ScaleDataType
*>
(
kargs
.
p_
x
scale
),
static_cast
<
const
Smooth
ScaleDataType
*>
(
kargs
.
p_
sm
scale
),
make_tuple
(
kargs
.
n
),
make_tuple
(
kargs
.
n
),
make_tuple
(
1
),
make_tuple
(
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
...
@@ -177,7 +177,7 @@ struct Smoothquant
...
@@ -177,7 +177,7 @@ struct Smoothquant
__shared__
char
smem
[
GetSmemSize
()];
__shared__
char
smem
[
GetSmemSize
()];
Pipeline
{}(
x_window
,
x
scale_window
,
yscale_window
,
qy_window
,
kargs
.
n
,
smem
);
Pipeline
{}(
x_window
,
sm
scale_window
,
yscale_window
,
qy_window
,
kargs
.
n
,
smem
);
}
}
};
};
...
...
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -28,7 +28,7 @@ struct SmoothquantPipelineDefaultPolicy
...
@@ -28,7 +28,7 @@ struct SmoothquantPipelineDefaultPolicy
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
Make
X
ScaleBlockTileDistribution
()
CK_TILE_DEVICE
static
constexpr
auto
Make
Smooth
ScaleBlockTileDistribution
()
{
{
using
S
=
typename
Problem
::
BlockShape
;
using
S
=
typename
Problem
::
BlockShape
;
...
...
include/ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp
View file @
dec32dc6
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -16,11 +16,11 @@ struct SmoothquantPipelineOnePass
...
@@ -16,11 +16,11 @@ struct SmoothquantPipelineOnePass
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Problem
=
ck_tile
::
remove_cvref_t
<
Problem_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
Policy
=
ck_tile
::
remove_cvref_t
<
Policy_
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
XDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
XDataType
>
;
using
X
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
X
ScaleDataType
>
;
using
Smooth
ScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
Smooth
ScaleDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
ComputeDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
ComputeDataType
>
;
using
QYDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
using
QYDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
QYDataType
>
;
using
YScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
YScaleDataType
=
ck_tile
::
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockSmoothquantProblem::kPadM
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockSmoothquantProblem::kPadM
...
@@ -39,9 +39,12 @@ struct SmoothquantPipelineOnePass
...
@@ -39,9 +39,12 @@ struct SmoothquantPipelineOnePass
return
Policy
::
template
GetSmemSize
<
Problem
>();
return
Policy
::
template
GetSmemSize
<
Problem
>();
}
}
template
<
typename
XWindow
,
typename
XScaleWindow
,
typename
QYWindow
,
typename
YScaleWindow
>
template
<
typename
XWindow
,
typename
SmoothScaleWindow
,
typename
QYWindow
,
typename
YScaleWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
CK_TILE_DEVICE
auto
operator
()(
const
XWindow
&
x_window_
,
const
X
ScaleWindow
&
x
scale_window_
,
const
Smooth
ScaleWindow
&
sm
scale_window_
,
YScaleWindow
&
yscale_window
,
YScaleWindow
&
yscale_window
,
QYWindow
&
qy_window
,
QYWindow
&
qy_window
,
ck_tile
::
index_t
,
ck_tile
::
index_t
,
...
@@ -49,8 +52,8 @@ struct SmoothquantPipelineOnePass
...
@@ -49,8 +52,8 @@ struct SmoothquantPipelineOnePass
{
{
auto
x_window
=
auto
x_window
=
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
make_tile_window
(
x_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
x
scale_window
=
make_tile_window
(
auto
sm
scale_window
=
make_tile_window
(
x
scale_window_
,
Policy
::
template
Make
X
ScaleBlockTileDistribution
<
Problem
>());
sm
scale_window_
,
Policy
::
template
Make
Smooth
ScaleBlockTileDistribution
<
Problem
>());
auto
reduce_absmax_func
=
ReduceOp
::
AbsMax
{};
auto
reduce_absmax_func
=
ReduceOp
::
AbsMax
{};
auto
reduce_absmax3_func
=
[](
auto
acc_
,
auto
v_0_
,
auto
v_1_
)
{
auto
reduce_absmax3_func
=
[](
auto
acc_
,
auto
v_0_
,
auto
v_1_
)
{
...
@@ -67,14 +70,14 @@ struct SmoothquantPipelineOnePass
...
@@ -67,14 +70,14 @@ struct SmoothquantPipelineOnePass
auto
block_reduce2d_cross_warp_sync
=
auto
block_reduce2d_cross_warp_sync
=
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
Policy
::
template
GetBlockReduce2dCrossWarpSync
<
Problem
>();
const
auto
x
=
load_tile
(
x_window
);
const
auto
x
=
load_tile
(
x_window
);
const
auto
x
scale
=
load_tile
(
x
scale_window
);
const
auto
sm
scale
=
load_tile
(
sm
scale_window
);
auto
y
=
tile_elementwise_in
(
auto
y
=
tile_elementwise_in
(
[
&
](
const
auto
&
a
,
const
auto
&
b
)
{
[
&
](
const
auto
&
a
,
const
auto
&
b
)
{
return
type_convert
<
ComputeDataType
>
(
a
)
*
type_convert
<
ComputeDataType
>
(
b
);
return
type_convert
<
ComputeDataType
>
(
a
)
*
type_convert
<
ComputeDataType
>
(
b
);
},
},
x
,
x
,
x
scale
);
sm
scale
);
// compute absmax, cross-lane->cross-warp
// compute absmax, cross-lane->cross-warp
auto
absmax
=
[
&
]()
{
auto
absmax
=
[
&
]()
{
...
@@ -110,7 +113,7 @@ struct SmoothquantPipelineOnePass
...
@@ -110,7 +113,7 @@ struct SmoothquantPipelineOnePass
sweep_tile
(
qy
,
[
&
](
auto
idx
)
{
sweep_tile
(
qy
,
[
&
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
auto
qy_
=
y
[
idx
]
/
yscale
[
i_idx
];
auto
qy_
=
y
[
idx
]
/
yscale
[
i_idx
];
qy
(
idx
)
=
saturates
<
QYDataType
>
{}(
qy_
);
qy
(
idx
)
=
type_convert
<
QYDataType
>
(
saturates
<
QYDataType
>
{}(
qy_
)
)
;
});
});
store_tile
(
qy_window
,
qy
);
store_tile
(
qy_window
,
qy
);
}
}
...
...
Prev
1
…
5
6
7
8
9
10
11
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment