Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
fb9f0757
Commit
fb9f0757
authored
Nov 19, 2024
by
Jun Liu
Browse files
Merge branch 'amd-develop' into amd-master
parents
4d914af3
c5ad2e80
Changes
183
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1145 additions
and
256 deletions
+1145
-256
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
...a/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
+2
-0
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
+232
-0
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp
...e/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp
+39
-0
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp
...ude/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp
+15
-0
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp
...de/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp
+23
-0
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
+50
-20
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
.../ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
+3
-3
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
...e/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
+44
-19
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
...line/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
+267
-63
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
+116
-38
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
...gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
+264
-52
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
+10
-6
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
...ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
+18
-9
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
...rm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
+14
-9
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
+20
-24
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
...layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
+1
-1
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
...ayernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
+11
-10
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
..._tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
+3
-1
include/ck_tile/ops/moe_sorting.hpp
include/ck_tile/ops/moe_sorting.hpp
+11
-0
include/ck_tile/ops/reduce/block/block_reduce2d.hpp
include/ck_tile/ops/reduce/block/block_reduce2d.hpp
+2
-1
No files found.
include/ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp
View file @
fb9f0757
...
@@ -863,6 +863,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
...
@@ -863,6 +863,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"N2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"N1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
static_assert
(
N0
!=
0
);
static_assert
(
N0
!=
0
);
...
...
include/ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp
0 → 100644
View file @
fb9f0757
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
struct
MoeSortingHostArgs
{
const
void
*
p_topk_ids
;
const
void
*
p_weights
;
void
*
p_sorted_token_ids
;
void
*
p_sorted_weights
;
void
*
p_sorted_expert_ids
;
void
*
p_total_tokens_post_pad
;
void
*
p_moe_buf
;
index_t
tokens
;
index_t
unit_size
;
index_t
num_experts
;
index_t
topk
;
index_t
moe_buf_bytes
;
};
template
<
typename
Problem_
>
struct
MoeSortingKernel
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
IndexType
=
typename
Problem
::
IndexType
;
using
WeightType
=
typename
Problem
::
WeightType
;
typedef
MoeSortingHostArgs
MoeSortingKargs
;
using
Hargs
=
MoeSortingHostArgs
;
struct
Kargs
{
const
void
*
p_topk_ids
;
const
void
*
p_weights
;
void
*
p_sorted_token_ids
;
void
*
p_sorted_weights
;
void
*
p_sorted_expert_ids
;
void
*
p_total_tokens_post_pad
;
void
*
p_moe_buf
;
index_t
tokens
;
index_t
num_experts
;
index_t
moe_buf_bytes
;
index_t
tokens_per_thread
;
mdiv
unit_size_mdiv
;
mdiv
topk_mdiv
;
};
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
h
)
{
// TODO: assume num-experts not too much
return
dim3
(
1
+
ck_tile
::
integer_divide_ceil
(
h
.
moe_buf_bytes
,
BlockSize
(
h
).
x
*
16
));
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
(
const
Hargs
&
h
)
{
return
dim3
(
ck_tile
::
integer_least_multiple
(
h
.
num_experts
,
ck_tile
::
get_warp_size
()));
}
// in byte
CK_TILE_HOST
static
constexpr
auto
GetSmemSize
(
const
Hargs
&
h
)
{
const
auto
blocks
=
BlockSize
(
h
);
return
((
blocks
.
x
+
1
)
*
h
.
num_experts
+
(
h
.
num_experts
+
1
))
*
sizeof
(
index_t
);
}
CK_TILE_HOST
static
constexpr
auto
MakeKargs
(
const
Hargs
&
h
)
{
Kargs
k
;
k
.
p_topk_ids
=
h
.
p_topk_ids
;
k
.
p_weights
=
h
.
p_weights
;
k
.
p_sorted_token_ids
=
h
.
p_sorted_token_ids
;
k
.
p_sorted_weights
=
h
.
p_sorted_weights
;
k
.
p_sorted_expert_ids
=
h
.
p_sorted_expert_ids
;
k
.
p_moe_buf
=
h
.
p_moe_buf
;
k
.
p_total_tokens_post_pad
=
h
.
p_total_tokens_post_pad
;
k
.
tokens
=
h
.
tokens
;
k
.
num_experts
=
h
.
num_experts
;
k
.
moe_buf_bytes
=
h
.
moe_buf_bytes
;
const
auto
blocks
=
BlockSize
(
h
);
k
.
tokens_per_thread
=
integer_divide_ceil
(
h
.
tokens
*
h
.
topk
,
blocks
.
x
);
k
.
unit_size_mdiv
=
mdiv
{
static_cast
<
uint32_t
>
(
h
.
unit_size
)};
k
.
topk_mdiv
=
mdiv
{
static_cast
<
uint32_t
>
(
h
.
topk
)};
return
k
;
}
CK_TILE_DEVICE
index_t
calc_index
(
index_t
total_col
,
index_t
row
,
index_t
col
)
const
{
return
row
*
total_col
+
col
;
}
CK_TILE_DEVICE
void
moe_buf_set_zero_kernel
(
uint8x16_t
*
buf
,
index_t
buf_bytes
)
const
{
const
index_t
offset
=
(
blockIdx
.
x
-
1
)
*
blockDim
.
x
+
threadIdx
.
x
;
if
(
offset
<
buf_bytes
/
16
)
{
buf
[
offset
]
=
uint8x16_t
{
0
};
}
}
CK_TILE_DEVICE
void
moe_align_block_size_kernel
(
const
IndexType
*
__restrict__
topk_id
,
const
WeightType
*
__restrict__
weights
,
index_t
*
p_sorted_token_ids
,
WeightType
*
p_sorted_weights
,
index_t
*
p_sorted_expert_ids
,
index_t
*
p_total_tokens_post_pad
,
const
index_t
num_experts
,
const
index_t
tokens_per_thread
,
const
index_t
numel
,
const
mdiv
unit_size_mdiv
,
const
mdiv
topk_mdiv
,
void
*
smem
)
const
{
const
index_t
tid
=
static_cast
<
index_t
>
(
threadIdx
.
x
);
const
index_t
start_idx
=
tid
*
tokens_per_thread
;
index_t
*
shared_mem
=
reinterpret_cast
<
index_t
*>
(
smem
);
index_t
*
tokens_cnts
=
shared_mem
;
// 2d: (blockDim.x + 1, num_experts)
index_t
*
cumsum
=
shared_mem
+
(
blockDim
.
x
+
1
)
*
num_experts
;
// 1: (num_experts + 1)
for
(
int
i
=
0
;
i
<
num_experts
;
++
i
)
{
tokens_cnts
[
calc_index
(
num_experts
,
tid
+
1
,
i
)]
=
0
;
}
#pragma unroll Problem_::InternalLoadUnroll
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
++
tokens_cnts
[
calc_index
(
num_experts
,
tid
+
1
,
topk_id
[
i
])];
}
__syncthreads
();
if
(
tid
<
num_experts
)
{
tokens_cnts
[
calc_index
(
num_experts
,
0
,
tid
)]
=
0
;
for
(
int
i
=
1
;
i
<=
static_cast
<
index_t
>
(
blockDim
.
x
);
++
i
)
{
tokens_cnts
[
calc_index
(
num_experts
,
i
,
tid
)]
+=
tokens_cnts
[
calc_index
(
num_experts
,
i
-
1
,
tid
)];
}
}
// __syncthreads();
if
(
tid
==
0
)
{
cumsum
[
0
]
=
0
;
for
(
int
i
=
1
;
i
<=
num_experts
;
++
i
)
{
auto
current_units
=
[
&
]()
{
index_t
x_
=
tokens_cnts
[
calc_index
(
num_experts
,
blockDim
.
x
,
i
-
1
)]
+
unit_size_mdiv
.
divisor
-
1
;
index_t
y_
=
unit_size_mdiv
.
div
(
x_
);
return
max
(
y_
,
1
)
*
unit_size_mdiv
.
divisor
;
}();
cumsum
[
i
]
=
cumsum
[
i
-
1
]
+
current_units
;
}
*
p_total_tokens_post_pad
=
cumsum
[
num_experts
];
}
__syncthreads
();
if
(
tid
<
num_experts
)
{
for
(
int
i
=
cumsum
[
tid
];
i
<
cumsum
[
tid
+
1
];
i
+=
unit_size_mdiv
.
divisor
)
{
p_sorted_expert_ids
[
unit_size_mdiv
.
div
(
i
)]
=
tid
;
}
}
#pragma unroll Problem_::InternalLoadUnroll
for
(
int
i
=
start_idx
;
i
<
numel
&&
i
<
start_idx
+
tokens_per_thread
;
++
i
)
{
index_t
expert_id
=
topk_id
[
i
];
index_t
rank_post_pad
=
tokens_cnts
[
calc_index
(
num_experts
,
tid
,
expert_id
)]
+
cumsum
[
expert_id
];
p_sorted_token_ids
[
rank_post_pad
]
=
topk_mdiv
.
div
(
i
);
p_sorted_weights
[
rank_post_pad
]
=
weights
[
i
];
++
tokens_cnts
[
calc_index
(
num_experts
,
tid
,
expert_id
)];
}
const
index_t
prefill_token
=
topk_mdiv
.
div
(
numel
);
if
(
tid
<
num_experts
)
{
index_t
expert_offset
=
cumsum
[
tid
]
+
tokens_cnts
[
calc_index
(
num_experts
,
blockDim
.
x
,
tid
)];
while
(
expert_offset
<
cumsum
[
tid
+
1
])
{
p_sorted_token_ids
[
expert_offset
]
=
prefill_token
;
p_sorted_weights
[
expert_offset
]
=
static_cast
<
WeightType
>
(
0.0
);
expert_offset
++
;
}
}
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
if
(
blockIdx
.
x
>
0
)
{
if
(
kargs
.
p_moe_buf
)
{
moe_buf_set_zero_kernel
(
reinterpret_cast
<
uint8x16_t
*>
(
kargs
.
p_moe_buf
),
kargs
.
moe_buf_bytes
);
}
return
;
}
const
size_t
numel
=
kargs
.
tokens
*
kargs
.
topk_mdiv
.
divisor
;
extern
__shared__
char
smem
[];
return
moe_align_block_size_kernel
(
static_cast
<
const
IndexType
*>
(
kargs
.
p_topk_ids
),
static_cast
<
const
WeightType
*>
(
kargs
.
p_weights
),
static_cast
<
IndexType
*>
(
kargs
.
p_sorted_token_ids
),
static_cast
<
WeightType
*>
(
kargs
.
p_sorted_weights
),
static_cast
<
IndexType
*>
(
kargs
.
p_sorted_expert_ids
),
static_cast
<
IndexType
*>
(
kargs
.
p_total_tokens_post_pad
),
kargs
.
num_experts
,
kargs
.
tokens_per_thread
,
numel
,
kargs
.
unit_size_mdiv
,
kargs
.
topk_mdiv
,
smem
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp
0 → 100644
View file @
fb9f0757
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include <string>
#include <type_traits>
#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
#endif
namespace
ck_tile
{
// template <typename Problem_, typename Policy_ = MoeSortingPolicy>
// struct MoeSortingPipeline
// {
// // TODO: this kernel only support warp per row
// using Problem = remove_cvref_t<Problem_>;
// using Policy = remove_cvref_t<Policy_>;
// using WeightType = typename Problem::WeightType;
// template <typename TopkIdWindow, typename WeightWindow>
// CK_TILE_DEVICE auto operator()(const TopkIdWindow& topk_id_window,
// const WeightWindow& weight_window,
// index_t* p_sorted_token_ids,
// WeightType* p_sorted_weights,
// index_t* p_sorted_expert_ids,
// index_t* p_total_tokens_post_pad,
// const index_t num_experts,
// const index_t unit_size,
// const size_t numel,
// const index_t topk)
// {
// }
// };
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp
0 → 100644
View file @
fb9f0757
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
namespace
ck_tile
{
struct
MoeSortingPolicy
{
};
}
// namespace ck_tile
include/ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp
0 → 100644
View file @
fb9f0757
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
template
<
typename
IndexType_
,
typename
WeightType_
,
index_t
InternalLoadUnroll_
>
struct
MoeSortingProblem
{
// TODO: this kernel only support warp per row
using
WeightType
=
remove_cvref_t
<
WeightType_
>
;
using
IndexType
=
remove_cvref_t
<
IndexType_
>
;
static
constexpr
index_t
WarpSize
=
get_warp_size
();
static
constexpr
index_t
WarpsPerBlock
=
1
;
static
constexpr
index_t
InternalLoadUnroll
=
InternalLoadUnroll_
;
};
}
// namespace ck_tile
include/ck_tile/ops/gemm/kernel/gemm_kernel.hpp
View file @
fb9f0757
...
@@ -115,12 +115,22 @@ struct GemmKernel
...
@@ -115,12 +115,22 @@ struct GemmKernel
}
}
}();
}();
auto
a_pad_view
=
pad_tensor_view
(
auto
a_pad_view
=
[
&
]()
{
a_tensor_view
,
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
// somehow clang-format is splitting below line into multiple.
return
pad_tensor_view
(
// clang-format off
a_tensor_view
,
sequence
<
false
,
GemmPipeline
::
kPadA
>
{});
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
a_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
// clang-format on
// clang-format on
auto
a_block_window
=
make_tile_window
(
auto
a_block_window
=
make_tile_window
(
...
@@ -128,12 +138,22 @@ struct GemmKernel
...
@@ -128,12 +138,22 @@ struct GemmKernel
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
i_m
,
0
});
{
i_m
,
0
});
auto
b_pad_view
=
pad_tensor_view
(
auto
b_pad_view
=
[
&
]()
{
b_tensor_view
,
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
{
// clang-format off
return
pad_tensor_view
(
sequence
<
false
,
GemmPipeline
::
kPadB
>
{});
b_tensor_view
,
// clang-format on
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
false
,
GemmPipeline
::
kPadK
>
{});
}
else
{
return
pad_tensor_view
(
b_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kN
>
{},
number
<
TilePartitioner
::
kK
>
{}),
sequence
<
GemmPipeline
::
kPadN
,
false
>
{});
}
}();
auto
b_block_window
=
make_tile_window
(
auto
b_block_window
=
make_tile_window
(
b_pad_view
,
b_pad_view
,
...
@@ -171,18 +191,28 @@ struct GemmKernel
...
@@ -171,18 +191,28 @@ struct GemmKernel
}
}
}();
}();
auto
c_pad_view
=
pad_tensor_view
(
auto
c_pad_view
=
[
&
]()
{
c_tensor_view
,
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
// clang-format off
return
pad_tensor_view
(
sequence
<
false
,
GemmPipeline
::
kPadC
>
{});
c_tensor_view
,
// clang-format on
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
auto
c_block_window
=
make_tile_window
(
sequence
<
false
,
GemmPipeline
::
kPadN
>
{});
}
else
{
return
pad_tensor_view
(
c_tensor_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
sequence
<
GemmPipeline
::
kPadM
,
false
>
{});
}
}();
auto
CBlockWindow_pad
=
make_tile_window
(
c_pad_view
,
c_pad_view
,
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
make_tuple
(
number
<
TilePartitioner
::
kM
>
{},
number
<
TilePartitioner
::
kN
>
{}),
{
i_m
,
i_n
});
{
i_m
,
i_n
});
EpiloguePipeline
{}(
c_b
lock
_w
indow
,
c_block_tile
);
EpiloguePipeline
{}(
CB
lock
W
indow
_pad
,
c_block_tile
);
}
}
};
};
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp
View file @
fb9f0757
...
@@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
...
@@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
bool
kPad
A
=
Problem
::
kPad
A
;
static
constexpr
bool
kPad
M
=
Problem
::
kPad
M
;
static
constexpr
bool
kPad
B
=
Problem
::
kPad
B
;
static
constexpr
bool
kPad
N
=
Problem
::
kPad
N
;
static
constexpr
bool
kPad
C
=
Problem
::
kPad
C
;
static
constexpr
bool
kPad
K
=
Problem
::
kPad
K
;
// Where is the right place for HasHotLoop and TailNum ???
// Where is the right place for HasHotLoop and TailNum ???
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
static
constexpr
bool
HasHotLoop
=
Problem
::
HasHotLoop
;
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp
View file @
fb9f0757
...
@@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeB
=
Problem
::
VectorSizeB
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
index_t
VectorSizeC
=
Problem
::
VectorSizeC
;
static
constexpr
bool
kPad
A
=
Problem
::
kPad
A
;
static
constexpr
bool
kPad
M
=
Problem
::
kPad
M
;
static
constexpr
bool
kPad
B
=
Problem
::
kPad
B
;
static
constexpr
bool
kPad
N
=
Problem
::
kPad
N
;
static
constexpr
bool
kPad
C
=
Problem
::
kPad
C
;
static
constexpr
bool
kPad
K
=
Problem
::
kPad
K
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetStaticLdsSize
()
{
{
...
@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
Policy
::
template
MakeADramTileDistribution
<
Problem
>());
// A LDS tile window for store
// A LDS tile window for store
auto
a_copy_lds_window
=
auto
a_copy_lds_window
=
make_tile_window
(
make_tile_window
(
a_lds_block
,
a_lds_block
,
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
make_tuple
(
number
<
kMPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
},
a_copy_dram_window
.
get_tile_distribution
());
// B DRAM tile window for load
// B DRAM tile window for load
auto
b_copy_dram_window
=
auto
b_copy_dram_window
=
...
@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
Policy
::
template
MakeBDramTileDistribution
<
Problem
>());
// B LDS tile window for store
// B LDS tile window for store
auto
b_copy_lds_window
=
auto
b_copy_lds_window
=
make_tile_window
(
make_tile_window
(
b_lds_block
,
b_lds_block
,
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
});
make_tuple
(
number
<
kNPerBlock
>
{},
number
<
kKPerBlock
>
{}),
{
0
,
0
},
b_copy_dram_window
.
get_tile_distribution
());
// A LDS tile for block GEMM
// A LDS tile for block GEMM
auto
a_lds_gemm_window
=
make_tile_window
(
auto
a_lds_gemm_window
=
make_tile_window
(
...
@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
tile_elementwise_inout
([](
auto
&
c
)
{
c
=
0
;
},
c_block_tile
);
// LDS write 0
// LDS write 0
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_block_tile
);
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
{
auto
a_shuffle_tmp
=
make_static_distributed_tensor
<
ADataType
>
(
Policy
::
template
MakeShuffledARegBlockDescriptor
<
Problem
>());
shuffle_tile
(
a_shuffle_tmp
,
a_block_tile
);
const
auto
a_block_tile_tmp
=
tile_elementwise_in
(
a_element_func
,
a_shuffle_tmp
);
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
}
else
{
store_tile
(
a_copy_lds_window
,
tile_elementwise_in
(
a_element_func
,
a_block_tile
));
}
// LDS write 0
// LDS write 0
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
{
auto
b_shuffle_tmp
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
b_shuffle_tmp
,
b_block_tile
);
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_shuffle_tmp
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
}
else
{
store_tile
(
b_copy_lds_window
,
tile_elementwise_in
(
b_element_func
,
b_block_tile
));
}
}
}
index_t
iCounter
=
num_loop
-
1
;
index_t
iCounter
=
num_loop
-
1
;
...
@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1
...
@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
store_tile
(
a_copy_lds_window
,
a_block_tile_tmp
);
// LDS write i + 1
// LDS write i + 1
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
{
auto
b_shuffle_tmp_loop
=
make_static_distributed_tensor
<
BDataType
>
(
Policy
::
template
MakeShuffledBRegBlockDescriptor
<
Problem
>());
shuffle_tile
(
b_shuffle_tmp_loop
,
b_block_tile
);
store_tile
(
b_copy_lds_window
,
tile_elementwise_in
(
b_element_func
,
b_shuffle_tmp_loop
));
}
else
{
const
auto
b_block_tile_tmp
=
tile_elementwise_in
(
b_element_func
,
b_block_tile
);
store_tile
(
b_copy_lds_window
,
b_block_tile_tmp
);
}
iCounter
--
;
iCounter
--
;
}
}
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp
View file @
fb9f0757
...
@@ -11,6 +11,7 @@ namespace ck_tile {
...
@@ -11,6 +11,7 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead
// Default policy class should not be templated, put template on member functions instead
struct
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
struct
GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{
{
#if 0
#if 0
// 2d
// 2d
template <typename Problem>
template <typename Problem>
...
@@ -116,6 +117,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -116,6 +117,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
return
smem_size
;
return
smem_size
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
return
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackB
()
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
return
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
}
#elif 1
#elif 1
// fake XOR
// fake XOR
template
<
typename
Problem
>
template
<
typename
Problem
>
...
@@ -192,80 +207,269 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
...
@@ -192,80 +207,269 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
16
/
sizeof
(
ADataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
{
#if 1 // coalesce reading for each blocks
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M0
=
MPerBlock
/
M1
;
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
M0
=
kMPerBlock
/
(
M2
*
M1
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
KPack
=
GetSmemPackA
<
Problem
>
();
return
make_static_tile_distribution
(
static_assert
(
KPack
%
K3
==
0
);
tile_distribution_encoding
<
sequence
<
1
>
,
constexpr
index_t
K2
=
KPack
/
K3
;
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
if
constexpr
(
get_warp_size
()
%
(
K2
*
M0
))
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
{
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
M0
);
sequence
<
1
,
2
>
,
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
sequence
<
0
,
1
>>
{});
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
#else // coalesce reading for each warps
return
make_static_tile_distribution
(
constexpr
index_t
M0
=
kBlockSize
/
get_warp_size
();
tile_distribution_encoding
<
sequence
<
1
>
,
constexpr
index_t
M1
=
kMPerBlock
/
(
M2
*
M0
);
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
return
make_static_tile_distribution
(
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
tile_distribution_encoding
<
sequence
<
1
>
,
sequence
<
2
,
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
sequence
<
3
,
1
>>
{});
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
}
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
else
sequence
<
1
,
2
>
,
{
sequence
<
1
,
1
>>
{});
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
#endif
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
16
/
sizeof
(
ADataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
{
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
else
{
constexpr
index_t
M0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
MPerBlock
/
(
M2
*
M0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
KPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
KPack
%
K3
==
0
);
constexpr
index_t
K2
=
KPack
/
K3
;
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
N2
*
K0
)
==
0
)
{
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"N2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"N1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
// coalesce reading for each warps
else
{
constexpr
index_t
N0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
NPerBlock
/
(
N2
*
N0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBRegBlockDescriptor
()
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
static_assert
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kNPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
16
/
sizeof
(
BDataType
);
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
kKPerBlock
/
K1
;
constexpr
index_t
N0
=
kNPerBlock
/
N1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
total_pixels
=
kNPerBlock
*
kKPerBlock
/
kBlockSize
;
#if 1 // coalesce reading for each blocks
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
K3
=
total_pixels
/
N1
;
static_assert
(
N2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
constexpr
index_t
kKPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
N1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
N0
=
kNPerBlock
/
(
N2
*
N1
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
return
make_static_tile_distribution
(
if
constexpr
(
warp_size
%
(
K2
*
N0
)
==
0
)
tile_distribution_encoding
<
sequence
<
1
>
,
{
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
constexpr
index_t
K1
=
warp_size
/
(
K2
*
N0
);
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
constexpr
index_t
K0
=
kBlockSize
/
warp_size
;
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
return
make_static_tile_distribution
(
sequence
<
0
,
1
>>
{});
tile_distribution_encoding
<
sequence
<
1
>
,
#else // coalesce reading for each warps
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
constexpr
index_t
N0
=
kBlockSize
/
get_warp_size
();
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
constexpr
index_t
N1
=
kNPerBlock
/
(
N2
*
N0
);
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
return
make_static_tile_distribution
(
sequence
<
1
,
3
>>
{});
tile_distribution_encoding
<
sequence
<
1
>
,
}
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
else
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
{
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
sequence
<
1
,
2
>
,
constexpr
index_t
K2_m
=
K2
/
K1
;
sequence
<
1
,
1
>>
{});
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
()
/
K1
;
#endif
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledARegBlockDescriptor
()
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
kMPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
kKPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
kMPerBlock
/
M1
;
constexpr
index_t
total_pixels
=
kMPerBlock
*
kKPerBlock
/
kBlockSize
;
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
kKPack
=
GetSmemPackA
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
M0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
M0
);
constexpr
index_t
K0
=
kBlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
kBlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
kKPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
include/ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp
View file @
fb9f0757
...
@@ -3,40 +3,133 @@
...
@@ -3,40 +3,133 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
static
constexpr
int
_VectorSize
=
16
;
template
<
typename
ADataType_
,
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
BlockGemmShape_
,
typename
TileGemmTraits_
>
typename
TileGemmTraits_
>
struct
GemmPipelineProblem
struct
GemmPipelineProblem
Base
{
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmTraits
::
ALayout
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmTraits
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmTraits
::
BLayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmTraits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmTraits
::
CLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmTraits
::
CLayout
>
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
index_t
VectorLoadSize
=
GemmTraits
::
_VectorSize
;
static
constexpr
bool
kPadA
=
GemmTraits
::
kPadA
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadB
=
GemmTraits
::
kPadB
;
static
constexpr
bool
kPadC
=
GemmTraits
::
kPadC
;
static
constexpr
bool
kPadM
=
GemmTraits
::
kPadM
;
static
constexpr
bool
kPadN
=
GemmTraits
::
kPadN
;
static
constexpr
bool
kPadK
=
GemmTraits
::
kPadK
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentA
()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
constexpr
index_t
pixels_per_thread
=
BlockGemmShape
::
kM
*
BlockGemmShape
::
kK
/
kBlockSize
;
return
pixels_per_thread
<
VectorLoadSize
/
sizeof
(
ADataType
)
?
pixels_per_thread
:
VectorLoadSize
/
sizeof
(
ADataType
);
}
else
{
return
VectorLoadSize
/
sizeof
(
ADataType
);
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentB
()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
{
constexpr
index_t
pixels_per_thread
=
BlockGemmShape
::
kN
*
BlockGemmShape
::
kK
/
kBlockSize
;
return
pixels_per_thread
<
VectorLoadSize
/
sizeof
(
BDataType
)
?
pixels_per_thread
:
VectorLoadSize
/
sizeof
(
BDataType
);
}
else
{
return
VectorLoadSize
/
sizeof
(
BDataType
);
}
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetAlignmentC
()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
constexpr
index_t
N1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
N2
=
std
::
min
(
BlockGemmShape
::
kN
/
N1
,
get_warp_size
());
constexpr
index_t
M0
=
get_warp_size
()
/
N2
;
constexpr
index_t
M1
=
BlockGemmShape
::
kM
/
M0
;
static
constexpr
index_t
VectorSizeA
=
kPadA
?
1
:
_VectorSize
/
sizeof
(
ADataType
);
return
std
::
min
(
M1
,
static_cast
<
index_t
>
(
VectorLoadSize
/
sizeof
(
CDataType
)));
static
constexpr
index_t
VectorSizeB
=
kPadB
?
1
:
_VectorSize
/
sizeof
(
BDataType
);
}
static
constexpr
index_t
VectorSizeC
=
kPadC
?
1
:
_VectorSize
/
sizeof
(
CDataType
);
else
{
constexpr
index_t
M1
=
kBlockSize
/
get_warp_size
();
constexpr
index_t
M2
=
std
::
min
(
BlockGemmShape
::
kM
/
M1
,
get_warp_size
());
constexpr
index_t
N0
=
get_warp_size
()
/
M2
;
constexpr
index_t
N1
=
BlockGemmShape
::
kN
/
N0
;
return
std
::
min
(
N1
,
static_cast
<
index_t
>
(
VectorLoadSize
/
sizeof
(
CDataType
)));
}
}
static
constexpr
index_t
VectorSizeA
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
ALayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
kPadK
?
1
:
GetAlignmentA
();
}
else
{
return
kPadM
?
1
:
GetAlignmentA
();
}
}();
static
constexpr
index_t
VectorSizeB
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
BLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
kPadN
?
1
:
GetAlignmentB
();
}
else
{
return
kPadK
?
1
:
GetAlignmentB
();
}
}();
static
constexpr
index_t
VectorSizeC
=
[]()
{
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
return
kPadN
?
1
:
GetAlignmentC
();
}
else
{
return
kPadM
?
1
:
GetAlignmentC
();
}
}();
};
};
// Alias for GemmPipelineProblem
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
BlockGemmShape_
,
typename
TileGemmTraits_
>
using
GemmPipelineProblem
=
GemmPipelineProblemBase
<
ADataType_
,
BDataType_
,
CDataType_
,
BlockGemmShape_
,
TileGemmTraits_
>
;
template
<
typename
ADataType_
,
template
<
typename
ADataType_
,
typename
BDataType_
,
typename
BDataType_
,
typename
CDataType_
,
typename
CDataType_
,
...
@@ -45,30 +138,15 @@ template <typename ADataType_,
...
@@ -45,30 +138,15 @@ template <typename ADataType_,
GemmPipelineScheduler
Scheduler_
=
GemmPipelineScheduler
::
Intrawave
,
GemmPipelineScheduler
Scheduler_
=
GemmPipelineScheduler
::
Intrawave
,
bool
HasHotLoop_
=
true
,
bool
HasHotLoop_
=
true
,
TailNumber
TailNum_
=
TailNumber
::
Full
>
TailNumber
TailNum_
=
TailNumber
::
Full
>
struct
UniversalGemmPipelineProblem
struct
UniversalGemmPipelineProblem
:
public
GemmPipelineProblemBase
<
ADataType_
,
BDataType_
,
CDataType_
,
BlockGemmShape_
,
TileGemmTraits_
>
{
{
using
ADataType
=
remove_cvref_t
<
ADataType_
>
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
using
BDataType
=
remove_cvref_t
<
BDataType_
>
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
using
CDataType
=
remove_cvref_t
<
CDataType_
>
;
static
constexpr
auto
TailNum
=
TailNum_
;
using
BlockGemmShape
=
remove_cvref_t
<
BlockGemmShape_
>
;
using
GemmTraits
=
remove_cvref_t
<
TileGemmTraits_
>
;
using
ALayout
=
remove_cvref_t
<
typename
GemmTraits
::
ALayout
>
;
using
BLayout
=
remove_cvref_t
<
typename
GemmTraits
::
BLayout
>
;
using
CLayout
=
remove_cvref_t
<
typename
GemmTraits
::
CLayout
>
;
static
constexpr
auto
Scheduler
=
Scheduler_
;
static
constexpr
auto
HasHotLoop
=
HasHotLoop_
;
static
constexpr
auto
TailNum
=
TailNum_
;
static
constexpr
index_t
kBlockSize
=
BlockGemmShape
::
NumWarps
*
get_warp_size
();
static
constexpr
bool
kPadA
=
GemmTraits
::
kPadA
;
static
constexpr
bool
kPadB
=
GemmTraits
::
kPadB
;
static
constexpr
bool
kPadC
=
GemmTraits
::
kPadC
;
static
constexpr
index_t
VectorSizeA
=
kPadA
?
_VectorSize
/
sizeof
(
ADataType
)
:
1
;
static
constexpr
index_t
VectorSizeB
=
kPadB
?
_VectorSize
/
sizeof
(
BDataType
)
:
1
;
static
constexpr
index_t
VectorSizeC
=
kPadC
?
_VectorSize
/
sizeof
(
CDataType
)
:
1
;
};
};
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp
View file @
fb9f0757
...
@@ -9,12 +9,8 @@
...
@@ -9,12 +9,8 @@
namespace
ck_tile
{
namespace
ck_tile
{
// UniversalGemm Policy
// UniversalGemm Policy
template
<
typename
LayoutA_
,
typename
LayoutB_
,
typename
LayoutC_
>
struct
UniversalGemmPipelineAgBgCrPolicy
struct
UniversalGemmPipelineAgBgCrPolicy
{
{
using
LayoutA
=
remove_cvref_t
<
LayoutA_
>
;
using
LayoutB
=
remove_cvref_t
<
LayoutB_
>
;
using
LayoutC
=
remove_cvref_t
<
LayoutC_
>
;
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I0
=
number
<
0
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
static
constexpr
auto
I1
=
number
<
1
>
{};
...
@@ -34,13 +30,14 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -34,13 +30,14 @@ struct UniversalGemmPipelineAgBgCrPolicy
TransposeC
>
;
TransposeC
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
Layout
A
>::
value
)
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
RowMajor
,
A
Layout
>::
value
)
{
{
constexpr
auto
MLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
ADataType
)
<
1
constexpr
auto
MLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
ADataType
)
<
1
?
1
?
1
...
@@ -176,13 +173,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -176,13 +173,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
K0
=
KPerBlock
/
K1
;
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
Layout
B
>::
value
)
if
constexpr
(
std
::
is_same
<
tensor_layout
::
gemm
::
ColumnMajor
,
B
Layout
>::
value
)
{
{
// NLdsLayer * K0 as logical Bank
// NLdsLayer * K0 as logical Bank
constexpr
auto
NLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
BDataType
)
<
1
constexpr
auto
NLdsLayer
=
32
*
4
/
KPerBlock
/
sizeof
(
BDataType
)
<
1
...
@@ -331,72 +330,285 @@ struct UniversalGemmPipelineAgBgCrPolicy
...
@@ -331,72 +330,285 @@ struct UniversalGemmPipelineAgBgCrPolicy
return
smem_size
;
return
smem_size
;
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackA
()
{
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
return
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSmemPackB
()
{
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
return
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeADramTileDistribution
()
{
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
typename
Problem
::
BDataType
,
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kM
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
if
constexpr
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
)
constexpr
index_t
K0
=
KPerBlock
/
K1
;
{
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
MPerBlock
/
M1
;
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
total_pixels
%
M1
==
0
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
constexpr
index_t
KPack
=
GetSmemPackA
<
Problem
>
();
static_assert
(
KPack
%
K3
==
0
);
return
make_static_tile_distribution
(
constexpr
index_t
K2
=
KPack
/
K3
;
tile_distribution_encoding
<
sequence
<
1
>
,
if
constexpr
(
get_warp_size
()
%
(
K2
*
M0
)
==
0
)
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
{
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
M0
);
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
sequence
<
1
,
2
>
,
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
sequence
<
0
,
1
>>
{});
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
M2
=
get_warp_size
()
/
K0
;
if
constexpr
(
get_warp_size
()
%
(
M2
*
K0
)
==
0
)
{
constexpr
index_t
M1
=
BlockSize
/
get_warp_size
();
static_assert
(
M2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
M1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
M0
=
MPerBlock
/
(
M2
*
M1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
else
{
constexpr
index_t
M0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
M1
=
MPerBlock
/
(
M2
*
M0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
,
M2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeBDramTileDistribution
()
{
{
using
WarpGemm
=
WarpGemmMfmaDispatcher
<
typename
Problem
::
ADataType
,
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
typename
Problem
::
BDataType
,
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
typename
Problem
::
CDataType
,
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I0
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I1
),
Problem
::
BlockGemmShape
::
WarpTile
::
at
(
I2
),
TransposeC
>
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
K1
=
WarpGemm
::
kK
;
if
constexpr
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
)
constexpr
index_t
K0
=
KPerBlock
/
K1
;
{
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
N2
!=
0
,
"M2 is zero, which will lead to a division by zero error."
);
static_assert
(
total_pixels
%
N1
==
0
);
static_assert
(
N1
!=
0
,
"M1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
constexpr
index_t
KPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
KPack
%
K3
==
0
);
return
make_static_tile_distribution
(
constexpr
index_t
K2
=
KPack
/
K3
;
tile_distribution_encoding
<
sequence
<
1
>
,
if
constexpr
(
get_warp_size
()
%
(
K2
*
N0
)
==
0
)
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
{
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
constexpr
index_t
K1
=
get_warp_size
()
/
(
K2
*
N0
);
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
();
sequence
<
1
,
2
>
,
static_assert
(
KPerBlock
==
K0
*
K1
*
K2
*
K3
);
sequence
<
0
,
1
>>
{});
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
2
,
1
>
,
sequence
<
3
,
1
>>
{});
}
}
else
{
constexpr
index_t
K1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
K0
=
KPerBlock
/
K1
;
constexpr
index_t
N2
=
get_warp_size
()
/
K0
;
// coalesce reading for each blocks
if
constexpr
(
get_warp_size
()
%
(
N2
*
K0
)
==
0
)
{
constexpr
index_t
N1
=
BlockSize
/
get_warp_size
();
static_assert
(
N2
!=
0
,
"N2 is zero, which will lead to a division by zero error."
);
static_assert
(
N1
!=
0
,
"N1 is zero, which will lead to a division by zero error."
);
constexpr
index_t
N0
=
NPerBlock
/
(
N2
*
N1
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
1
>>
{});
}
// coalesce reading for each warps
else
{
constexpr
index_t
N0
=
BlockSize
/
get_warp_size
();
constexpr
index_t
N1
=
NPerBlock
/
(
N2
*
N0
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
,
N2
>
,
sequence
<
K0
,
K1
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
1
>>
{});
}
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledARegBlockDescriptor
()
{
using
ALayout
=
remove_cvref_t
<
typename
Problem
::
ALayout
>
;
using
ADataType
=
remove_cvref_t
<
typename
Problem
::
ADataType
>
;
static_assert
(
std
::
is_same_v
<
ALayout
,
ck_tile
::
tensor_layout
::
gemm
::
ColumnMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
MPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
M1
=
Problem
::
VectorLoadSize
/
sizeof
(
ADataType
);
constexpr
index_t
M0
=
MPerBlock
/
M1
;
constexpr
index_t
total_pixels
=
MPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
M1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
M1
;
constexpr
index_t
kKPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
M0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
M0
);
constexpr
index_t
K0
=
BlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
M0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
M0
,
M1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeShuffledBRegBlockDescriptor
()
{
using
BLayout
=
remove_cvref_t
<
typename
Problem
::
BLayout
>
;
using
BDataType
=
remove_cvref_t
<
typename
Problem
::
BDataType
>
;
static_assert
(
std
::
is_same_v
<
BLayout
,
ck_tile
::
tensor_layout
::
gemm
::
RowMajor
>
);
constexpr
index_t
BlockSize
=
Problem
::
kBlockSize
;
constexpr
index_t
NPerBlock
=
Problem
::
BlockGemmShape
::
kN
;
constexpr
index_t
KPerBlock
=
Problem
::
BlockGemmShape
::
kK
;
constexpr
index_t
N1
=
Problem
::
VectorLoadSize
/
sizeof
(
BDataType
);
constexpr
index_t
N0
=
NPerBlock
/
N1
;
constexpr
index_t
total_pixels
=
NPerBlock
*
KPerBlock
/
BlockSize
;
static_assert
(
total_pixels
%
N1
==
0
);
constexpr
index_t
K3
=
total_pixels
/
N1
;
constexpr
index_t
kKPack
=
GetSmemPackB
<
Problem
>
();
static_assert
(
kKPack
%
K3
==
0
);
constexpr
index_t
K2
=
kKPack
/
K3
;
// TODO: this dimention could be outside single wave
constexpr
index_t
warp_size
=
get_warp_size
();
if
constexpr
(
warp_size
%
(
K2
*
N0
)
==
0
)
{
constexpr
index_t
K1
=
warp_size
/
(
K2
*
N0
);
constexpr
index_t
K0
=
BlockSize
/
warp_size
;
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2
,
K3
>>
,
tuple
<
sequence
<
2
>
,
sequence
<
2
,
1
,
2
>>
,
tuple
<
sequence
<
0
>
,
sequence
<
1
,
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
else
{
constexpr
index_t
K1
=
(
K2
*
N0
)
/
get_warp_size
();
constexpr
index_t
K2_m
=
K2
/
K1
;
constexpr
index_t
K0
=
BlockSize
/
get_warp_size
()
/
K1
;
static_assert
(
KPerBlock
==
K0
*
K1
*
K2_m
*
K3
);
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
N0
,
N1
>
,
sequence
<
K0
,
K1
,
K2_m
,
K3
>>
,
tuple
<
sequence
<
2
,
2
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
0
,
1
>
,
sequence
<
0
,
2
>>
,
sequence
<
1
,
2
>
,
sequence
<
1
,
3
>>
{});
}
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
...
...
include/ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp
View file @
fb9f0757
...
@@ -3,19 +3,23 @@
...
@@ -3,19 +3,23 @@
#pragma once
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
template
<
bool
kPad
A
_
,
template
<
bool
kPad
M
_
,
bool
kPad
B
_
,
bool
kPad
N
_
,
bool
kPad
C
_
,
bool
kPad
K
_
,
typename
ALayout_
,
typename
ALayout_
,
typename
BLayout_
,
typename
BLayout_
,
typename
CLayout_
>
typename
CLayout_
>
struct
TileGemmTraits
struct
TileGemmTraits
{
{
static
constexpr
bool
kPadA
=
kPadA_
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadB
=
kPadB_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadC
=
kPadC_
;
static
constexpr
bool
kPadK
=
kPadK_
;
static
constexpr
int
_VectorSize
=
16
;
using
ALayout
=
ALayout_
;
using
ALayout
=
ALayout_
;
using
BLayout
=
BLayout_
;
using
BLayout
=
BLayout_
;
...
...
include/ck_tile/ops/layernorm2d/kernel/layernorm2d_fwd_kernel.hpp
View file @
fb9f0757
...
@@ -28,7 +28,10 @@ struct Layernorm2dFwdHostArgs
...
@@ -28,7 +28,10 @@ struct Layernorm2dFwdHostArgs
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// x row_stride
index_t
xr_stride
;
// x residule row stride
index_t
y_stride
;
// y row stride
index_t
yr_stride
;
// y residule row stride
};
};
// TODO: Extract some type to wrapper class
// TODO: Extract some type to wrapper class
...
@@ -93,7 +96,10 @@ struct Layernorm2dFwd
...
@@ -93,7 +96,10 @@ struct Layernorm2dFwd
index_t
m
;
index_t
m
;
index_t
n
;
index_t
n
;
index_t
stride
;
// row_stride
index_t
x_stride
;
// x row_stride
index_t
xr_stride
;
// x residule row stride
index_t
y_stride
;
// y row stride
index_t
yr_stride
;
// y residule row stride
};
};
using
Hargs
=
Layernorm2dFwdHostArgs
;
using
Hargs
=
Layernorm2dFwdHostArgs
;
...
@@ -112,12 +118,15 @@ struct Layernorm2dFwd
...
@@ -112,12 +118,15 @@ struct Layernorm2dFwd
hargs
.
epsilon
,
hargs
.
epsilon
,
hargs
.
m
,
hargs
.
m
,
hargs
.
n
,
hargs
.
n
,
hargs
.
stride
};
hargs
.
x_stride
,
hargs
.
xr_stride
,
hargs
.
y_stride
,
hargs
.
yr_stride
};
}
}
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
{
{
return
(
hargs
.
m
+
Block_M
-
1
)
/
Block_M
;
return
dim3
(
integer_divide_ceil
(
hargs
.
m
,
Block_M
))
;
}
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockShape
::
BlockSize
;
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockShape
::
BlockSize
;
}
...
@@ -165,7 +174,7 @@ struct Layernorm2dFwd
...
@@ -165,7 +174,7 @@ struct Layernorm2dFwd
return
base_str
;
return
base_str
;
}();
}();
return
_SS_
(
"layernorm2d_fwd_"
)
+
_SS_
(
prec_str
)
+
"_"
+
return
_SS_
(
"layernorm2d_fwd_"
)
+
_SS_
(
prec_str
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N
)
+
"_"
+
_TS_
(
S_
::
Block_M
)
+
"x"
+
_TS_
(
S_
::
Block_N
)
+
"_"
+
_TS_
(
S_
::
WarpPerBlock_M
)
+
"x"
+
_TS_
(
S_
::
WarpPerBlock_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_TS_
(
S_
::
Warp_M
)
+
"x"
+
_TS_
(
S_
::
Warp_N
)
+
"_"
+
_TS_
(
S_
::
Vector_M
)
+
"x"
+
_TS_
(
S_
::
Vector_N
)
+
"_"
+
_SS_
(
Pipeline
::
name
)
+
surfix
;
_SS_
(
Pipeline
::
name
)
+
surfix
;
...
@@ -182,7 +191,7 @@ struct Layernorm2dFwd
...
@@ -182,7 +191,7 @@ struct Layernorm2dFwd
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
static_cast
<
const
XDataType
*>
(
kargs
.
p_x
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
x_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -201,7 +210,7 @@ struct Layernorm2dFwd
...
@@ -201,7 +210,7 @@ struct Layernorm2dFwd
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
const
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
const
XResidualDataType
*>
(
kargs
.
p_x_residual
),
static_cast
<
const
XResidualDataType
*>
(
kargs
.
p_x_residual
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
xr_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -250,7 +259,7 @@ struct Layernorm2dFwd
...
@@ -250,7 +259,7 @@ struct Layernorm2dFwd
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
static_cast
<
YDataType
*>
(
kargs
.
p_y
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
y_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
@@ -266,7 +275,7 @@ struct Layernorm2dFwd
...
@@ -266,7 +275,7 @@ struct Layernorm2dFwd
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
auto
tmp_
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
static_cast
<
YResidualDataType
*>
(
kargs
.
p_y_residual
),
static_cast
<
YResidualDataType
*>
(
kargs
.
p_y_residual
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
m
,
kargs
.
n
),
make_tuple
(
kargs
.
stride
,
1
),
make_tuple
(
kargs
.
yr_
stride
,
1
),
number
<
Vector_N
>
{},
number
<
Vector_N
>
{},
number
<
1
>
{});
number
<
1
>
{});
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_default_policy.hpp
View file @
fb9f0757
...
@@ -26,6 +26,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
...
@@ -26,6 +26,7 @@ struct Layernorm2dFwdPipelineDefaultPolicy
sequence
<
1
,
1
,
2
,
2
>
,
sequence
<
1
,
1
,
2
,
2
>
,
sequence
<
0
,
3
,
0
,
3
>>
{});
sequence
<
0
,
3
,
0
,
3
>>
{});
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeGammaBetaBlockTileDistribution
()
CK_TILE_DEVICE
static
constexpr
auto
MakeGammaBetaBlockTileDistribution
()
{
{
...
@@ -44,9 +45,10 @@ struct Layernorm2dFwdPipelineDefaultPolicy
...
@@ -44,9 +45,10 @@ struct Layernorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelford
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelford
()
{
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
return
BlockWelford
<
P_
>
{};
return
BlockWelford
<
P_
>
{};
}
}
...
@@ -54,9 +56,10 @@ struct Layernorm2dFwdPipelineDefaultPolicy
...
@@ -54,9 +56,10 @@ struct Layernorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelfordSync
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelfordSync
()
{
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
return
BlockWelfordSync
<
P_
>
{};
return
BlockWelfordSync
<
P_
>
{};
}
}
...
@@ -64,9 +67,10 @@ struct Layernorm2dFwdPipelineDefaultPolicy
...
@@ -64,9 +67,10 @@ struct Layernorm2dFwdPipelineDefaultPolicy
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelfordCrossWarpSync
()
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockWelfordCrossWarpSync
()
{
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
return
BlockWelfordCrossWarpSync
<
P_
>
{};
return
BlockWelfordCrossWarpSync
<
P_
>
{};
}
}
...
@@ -76,13 +80,14 @@ struct Layernorm2dFwdPipelineDefaultPolicy
...
@@ -76,13 +80,14 @@ struct Layernorm2dFwdPipelineDefaultPolicy
{
{
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
if
constexpr
(
Problem
::
kNeedCrossWarpSync
)
{
{
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
X
DataType
,
using
P_
=
BlockWelfordProblem
<
typename
Problem
::
Compute
DataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
ComputeDataType
,
typename
Problem
::
BlockShape
>
;
typename
Problem
::
BlockShape
,
Problem
::
Traits
::
kFastFDiv
>
;
using
block_welford
=
BlockWelford
<
P_
>
;
using
block_welford
=
BlockWelford
<
P_
>
;
using
x_block_tile
=
using
x_block_tile
=
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
X
DataType
>
(
decltype
(
make_static_distributed_tensor
<
typename
Problem
::
Compute
DataType
>
(
MakeXBlockTileDistribution
<
Problem
>
()));
MakeXBlockTileDistribution
<
Problem
>
()));
using
mean_var_block_tile
=
using
mean_var_block_tile
=
decltype
(
block_welford
::
template
MakeMeanVarBlockTile
<
x_block_tile
>());
decltype
(
block_welford
::
template
MakeMeanVarBlockTile
<
x_block_tile
>());
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_one_pass.hpp
View file @
fb9f0757
...
@@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -36,6 +36,7 @@ struct Layernorm2dFwdPipelineOnePass
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kNeedCrossWarpSync
=
Problem
::
kNeedCrossWarpSync
;
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadM
=
false
;
// TODO - BlockLayernorm2dFwdProblem::kPadM
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
kFastFDiv
=
Problem
::
Traits
::
kFastFDiv
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedAdd
=
Problem
::
Traits
::
kFusedAdd
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
static
constexpr
auto
kFusedQuant
=
Problem
::
Traits
::
kFusedQuant
;
...
@@ -87,12 +88,9 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -87,12 +88,9 @@ struct Layernorm2dFwdPipelineOnePass
x_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
x_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
auto
y_residual_window
=
make_tile_window
(
auto
y_residual_window
=
make_tile_window
(
y_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
y_residual_window_
,
Policy
::
template
MakeXBlockTileDistribution
<
Problem
>());
const
auto
x_scale_window
=
make_tile_window
(
x_scale_window_
,
Policy
::
template
MakeGammaBetaBlockTileDistribution
<
Problem
>());
auto
x
=
load_tile
(
x_window
);
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
auto
x_scale
=
load_tile
(
x_scale_window
);
int
cur_count
=
0
;
int
cur_count
=
0
;
int
max_count
=
int
max_count
=
...
@@ -106,20 +104,21 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -106,20 +104,21 @@ struct Layernorm2dFwdPipelineOnePass
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
const
auto
beta
=
load_tile
(
beta_window
);
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
{
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
// compute x = x_resi + x
x
(
idx
)
=
type_convert
<
YResidualDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
type_convert
<
YResidualDataType
>
(
x
(
idx
));
});
});
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
)
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
)
store_tile
(
y_residual_window
,
x
);
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
)
);
}
}
// compute welford each-thread->cross-lane->cross-warp
// compute welford each-thread->cross-lane->cross-warp
auto
[
mean
,
var
]
=
block_welford
(
x
,
cur_count
,
max_count
);
auto
[
mean
,
var
]
=
block_welford
(
acc
,
cur_count
,
max_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_welford_cross_warp_sync
(
mean
,
var
,
cur_count
,
smem
);
block_tile_welford_post_scale_var
(
var
,
cur_count
);
block_tile_welford_post_scale_var
(
var
,
cur_count
);
...
@@ -127,7 +126,15 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -127,7 +126,15 @@ struct Layernorm2dFwdPipelineOnePass
// compute inv-std
// compute inv-std
auto
inv_std
=
tile_elementwise_in
(
auto
inv_std
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
[
&
](
const
auto
&
v_
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
(
sqrt
(
v_
+
epsilon
));
if
(
kFastFDiv
&&
std
::
is_same_v
<
ComputeDataType
,
float
>
)
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
*
__builtin_amdgcn_rcpf
(
sqrt
(
v_
+
epsilon
));
}
else
{
return
type_convert
<
ComputeDataType
>
(
1.0
f
)
/
sqrt
(
v_
+
epsilon
);
}
},
},
var
);
var
);
...
@@ -137,7 +144,7 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -137,7 +144,7 @@ struct Layernorm2dFwdPipelineOnePass
store_tile
(
inv_std_window
,
cast_tile
<
InvStdDataType
>
(
inv_std
));
store_tile
(
inv_std_window
,
cast_tile
<
InvStdDataType
>
(
inv_std
));
// layernorm computation
// layernorm computation
auto
ln
=
make_static_distributed_tensor
<
ComputeDataType
>
(
x
.
get_tile_distribution
());
auto
ln
=
make_static_distributed_tensor
<
ComputeDataType
>
(
acc
.
get_tile_distribution
());
sweep_tile
(
ln
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
sweep_tile
(
ln
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
...
@@ -145,26 +152,15 @@ struct Layernorm2dFwdPipelineOnePass
...
@@ -145,26 +152,15 @@ struct Layernorm2dFwdPipelineOnePass
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
ln_
=
(
acc
[
idx
]
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
auto
ln_
=
(
x_
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
ln
(
idx
)
=
ln_
;
ln
(
idx
)
=
ln_
;
});
});
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
// smooth-quant pre-scale, then run rowwise-quant
sweep_tile
(
ln
,
[
&
](
auto
idx
)
{
constexpr
auto
j_idx
=
make_tuple
(
idx
[
number
<
1
>
{}]);
const
auto
xs_
=
type_convert
<
ComputeDataType
>
(
x_scale
[
j_idx
]);
ln
(
idx
)
=
ln
(
idx
)
*
xs_
;
});
}
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
||
if
constexpr
(
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
DYNAMIC_QUANT
||
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
kFusedQuant
==
Layernorm2dFusedQuantEnum
::
SMOOTH_DYNAMIC_QUANT
)
{
{
Epilogue
{}(
y_window_
,
y_scale_window
,
ln
,
smem
);
Epilogue
{}(
y_window_
,
x_scale_window_
,
y_scale_window
,
ln
,
smem
);
}
}
else
else
Epilogue
{}(
y_window_
,
ln
);
Epilogue
{}(
y_window_
,
ln
);
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_problem.hpp
View file @
fb9f0757
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_pipeline_two_pass.hpp
View file @
fb9f0757
...
@@ -106,7 +106,7 @@ struct Layernorm2dFwdPipelineTwoPass
...
@@ -106,7 +106,7 @@ struct Layernorm2dFwdPipelineTwoPass
auto
block_welford_cross_warp_sync
=
auto
block_welford_cross_warp_sync
=
Policy
::
template
GetBlockWelfordCrossWarpSync
<
Problem
>();
Policy
::
template
GetBlockWelfordCrossWarpSync
<
Problem
>();
using
XTensorType
=
decltype
(
load_tile
(
x_window
));
using
XTensorType
=
decltype
(
cast_tile
<
ComputeDataType
>
(
load_tile
(
x_window
))
)
;
auto
mean
=
block_welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
mean
=
block_welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
var
=
block_welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
auto
var
=
block_welford
.
template
MakeMeanVarBlockTile
<
XTensorType
>();
...
@@ -117,21 +117,22 @@ struct Layernorm2dFwdPipelineTwoPass
...
@@ -117,21 +117,22 @@ struct Layernorm2dFwdPipelineTwoPass
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
x_window
,
{
0
,
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
Block_N
});
move_tile_window
(
x_residual_window
,
{
0
,
Block_N
});
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
{
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
// compute x = x_resi + x
x
(
idx
)
=
type_convert
<
YResidualDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
type_convert
<
YResidualDataType
>
(
x
(
idx
));
});
});
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
)
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
)
{
{
store_tile
(
y_residual_window
,
x
);
store_tile
(
y_residual_window
,
cast_tile
<
YResidualDataType
>
(
acc
)
);
move_tile_window
(
y_residual_window
,
{
0
,
Block_N
});
move_tile_window
(
y_residual_window
,
{
0
,
Block_N
});
}
}
}
}
block_welford
(
x
,
mean
,
var
,
cur_count
,
max_count
);
block_welford
(
acc
,
mean
,
var
,
cur_count
,
max_count
);
}
}
block_welford_sync
(
mean
,
var
,
cur_count
);
block_welford_sync
(
mean
,
var
,
cur_count
);
...
@@ -165,20 +166,21 @@ struct Layernorm2dFwdPipelineTwoPass
...
@@ -165,20 +166,21 @@ struct Layernorm2dFwdPipelineTwoPass
{
{
auto
x
=
load_tile
(
x_window
);
auto
x
=
load_tile
(
x_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
auto
x_resi
=
load_tile
(
x_residual_window
);
auto
acc
=
cast_tile
<
ComputeDataType
>
(
x
);
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
if
constexpr
(
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD_STORE
||
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
kFusedAdd
==
Layernorm2dFusedAddEnum
::
PRE_ADD
)
{
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
sweep_tile
(
x_resi
,
[
&
](
auto
idx
)
{
// compute x = x_resi + x
// compute x = x_resi + x
x
(
idx
)
=
type_convert
<
YResidualDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
)
=
type_convert
<
ComputeDataType
>
(
x_resi
(
idx
))
+
acc
(
idx
);
type_convert
<
YResidualDataType
>
(
x
(
idx
));
});
});
}
}
// load gamma/beta (TODO: support no gamma/beta?)
// load gamma/beta (TODO: support no gamma/beta?)
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
gamma
=
load_tile
(
gamma_window
);
const
auto
beta
=
load_tile
(
beta_window
);
const
auto
beta
=
load_tile
(
beta_window
);
auto
ln
=
make_static_distributed_tensor
<
ComputeDataType
>
(
x
.
get_tile_distribution
());
auto
ln
=
make_static_distributed_tensor
<
ComputeDataType
>
(
acc
.
get_tile_distribution
());
sweep_tile
(
ln
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
sweep_tile
(
ln
,
[
&
,
mean_
=
mean
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
constexpr
auto
i_idx
=
make_tuple
(
idx
[
number
<
0
>
{}]);
...
@@ -187,8 +189,7 @@ struct Layernorm2dFwdPipelineTwoPass
...
@@ -187,8 +189,7 @@ struct Layernorm2dFwdPipelineTwoPass
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
gamma_
=
type_convert
<
ComputeDataType
>
(
gamma
[
j_idx
]);
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
const
auto
beta_
=
type_convert
<
ComputeDataType
>
(
beta
[
j_idx
]);
const
auto
x_
=
type_convert
<
ComputeDataType
>
(
x
[
idx
]);
auto
ln_
=
(
acc
(
idx
)
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
auto
ln_
=
(
x_
-
mean_
[
i_idx
])
*
inv_std
[
i_idx
]
*
gamma_
+
beta_
;
ln
(
idx
)
=
ln_
;
ln
(
idx
)
=
ln_
;
});
});
...
...
include/ck_tile/ops/layernorm2d/pipeline/layernorm2d_fwd_traits.hpp
View file @
fb9f0757
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -39,6 +39,7 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
...
@@ -39,6 +39,7 @@ template<> struct Layernorm2dFusedQuantEnumName<Layernorm2dFusedQuantEnum::SMOOT
template
<
bool
kPadN_
,
template
<
bool
kPadN_
,
bool
kSaveMeanInvStd_
,
bool
kSaveMeanInvStd_
,
bool
kFastFDiv_
,
bool
kTwoPass_
,
bool
kTwoPass_
,
Layernorm2dFusedAddEnum
kFusedAdd_
,
Layernorm2dFusedAddEnum
kFusedAdd_
,
Layernorm2dFusedQuantEnum
kFusedQuant_
>
Layernorm2dFusedQuantEnum
kFusedQuant_
>
...
@@ -46,6 +47,7 @@ struct Layernorm2dFwdTraits
...
@@ -46,6 +47,7 @@ struct Layernorm2dFwdTraits
{
{
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kSaveMeanInvStd
=
kSaveMeanInvStd_
;
static
constexpr
bool
kFastFDiv
=
kFastFDiv_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
bool
kTwoPass
=
kTwoPass_
;
static
constexpr
Layernorm2dFusedAddEnum
kFusedAdd
=
kFusedAdd_
;
static
constexpr
Layernorm2dFusedAddEnum
kFusedAdd
=
kFusedAdd_
;
static
constexpr
Layernorm2dFusedQuantEnum
kFusedQuant
=
kFusedQuant_
;
static
constexpr
Layernorm2dFusedQuantEnum
kFusedQuant
=
kFusedQuant_
;
...
...
include/ck_tile/ops/moe_sorting.hpp
0 → 100644
View file @
fb9f0757
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/reduce/block/block_reduce2d.hpp
View file @
fb9f0757
...
@@ -29,7 +29,8 @@ struct BlockReduce2d
...
@@ -29,7 +29,8 @@ struct BlockReduce2d
sweep_tile
<
XDistributedTensor_
>
(
sweep_tile
<
XDistributedTensor_
>
(
[
&
](
auto
...
idx_
)
{
[
&
](
auto
...
idx_
)
{
constexpr
auto
idx_0
=
make_tuple
(
make_tuple
(
idx_
[
number
<
0
>
{}]...)[
number
<
0
>
{}]);
constexpr
auto
idx_0
=
make_tuple
(
make_tuple
(
idx_
[
number
<
0
>
{}]...)[
number
<
0
>
{}]);
y_tensor
(
idx_0
)
=
reduce_func
(
y_tensor
(
idx_0
),
x_tensor
[
idx_
]...);
y_tensor
(
idx_0
)
=
reduce_func
(
y_tensor
(
idx_0
),
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_tensor
[
idx_
])...);
},
},
ReducePacksPerXDim
{});
ReducePacksPerXDim
{});
#if 0
#if 0
...
...
Prev
1
2
3
4
5
6
7
8
9
10
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment