Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ec959387
Unverified
Commit
ec959387
authored
Feb 13, 2025
by
rocking
Committed by
GitHub
Feb 13, 2025
Browse files
Merge branch 'develop' into ck_tile/fmha_receipt_aiter
parents
c1e2fef7
0e5e29c4
Changes
393
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
658 additions
and
162 deletions
+658
-162
include/ck_tile/host/reference/reference_batched_transpose.hpp
...de/ck_tile/host/reference/reference_batched_transpose.hpp
+59
-0
include/ck_tile/host/reference/reference_gemm.hpp
include/ck_tile/host/reference/reference_gemm.hpp
+3
-2
include/ck_tile/host/reference/reference_moe_sorting.hpp
include/ck_tile/host/reference/reference_moe_sorting.hpp
+24
-2
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
+1
-0
include/ck_tile/ops/batched_transpose.hpp
include/ck_tile/ops/batched_transpose.hpp
+12
-0
include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp
...ops/batched_transpose/kernel/batched_transpose_kernel.hpp
+129
-0
include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp
...batched_transpose/pipeline/batched_transpose_pipeline.hpp
+52
-0
include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp
...s/batched_transpose/pipeline/batched_transpose_policy.hpp
+44
-0
include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp
.../batched_transpose/pipeline/batched_transpose_problem.hpp
+48
-0
include/ck_tile/ops/common.hpp
include/ck_tile/ops/common.hpp
+1
-0
include/ck_tile/ops/common/utils.hpp
include/ck_tile/ops/common/utils.hpp
+34
-0
include/ck_tile/ops/elementwise.hpp
include/ck_tile/ops/elementwise.hpp
+1
-0
include/ck_tile/ops/epilogue.hpp
include/ck_tile/ops/epilogue.hpp
+1
-0
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
+147
-151
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
+97
-4
include/ck_tile/ops/flatmm.hpp
include/ck_tile/ops/flatmm.hpp
+1
-0
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc
.../block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc
+1
-1
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc
...ck/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc
+1
-1
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc
...tmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc
+1
-1
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+1
-0
No files found.
include/ck_tile/host/reference/reference_batched_transpose.hpp
0 → 100644
View file @
ec959387
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <stdexcept>
#include <string>
#include <thread>
namespace ck_tile {

/// @brief Reference (CPU) implementation of a batched 4-D layout transpose.
///
/// Supported conversions: "NCHW" -> "NHWC" and "NHWC" -> "NCHW".
/// Batches are processed in parallel via make_ParallelTensorFunctor.
///
/// @param x          source tensor, indexed as (batch, d1, d2, d3) per layout_in
/// @param y          destination tensor, pre-sized for layout_out
/// @param layout_in  input layout tag, "NCHW" or "NHWC"
/// @param layout_out output layout tag, "NHWC" or "NCHW"
/// @throws std::invalid_argument on an unsupported layout pair
///         (the original code silently left y untouched in that case)
template <typename Type>
CK_TILE_HOST void reference_batched_transpose(const HostTensor<Type>& x,
                                              HostTensor<Type>& y,
                                              std::string layout_in  = "NCHW",
                                              std::string layout_out = "NHWC")
{
    const bool nchw_to_nhwc = (layout_in == "NCHW" && layout_out == "NHWC");
    const bool nhwc_to_nchw = (layout_in == "NHWC" && layout_out == "NCHW");

    // Fail fast before launching worker threads; a silent no-op here would
    // hand the caller an uninitialized y tensor.
    if(!nchw_to_nhwc && !nhwc_to_nchw)
    {
        throw std::invalid_argument("reference_batched_transpose: unsupported layout pair " +
                                    layout_in + " -> " + layout_out);
    }

    const int N = x.mDesc.get_lengths()[0];

    auto per_batch = [&](auto batch) {
        if(nchw_to_nhwc)
        {
            const int C = x.mDesc.get_lengths()[1];
            const int H = x.mDesc.get_lengths()[2];
            const int W = x.mDesc.get_lengths()[3];
            for(int c = 0; c < C; ++c)
            {
                for(int h = 0; h < H; ++h)
                {
                    for(int w = 0; w < W; ++w)
                    {
                        y(batch, h, w, c) = x(batch, c, h, w);
                    }
                }
            }
        }
        else // nhwc_to_nchw
        {
            const int H = x.mDesc.get_lengths()[1];
            const int W = x.mDesc.get_lengths()[2];
            const int C = x.mDesc.get_lengths()[3];
            for(int h = 0; h < H; ++h)
            {
                for(int w = 0; w < W; ++w)
                {
                    for(int c = 0; c < C; ++c)
                    {
                        y(batch, c, h, w) = x(batch, h, w, c);
                    }
                }
            }
        }
    };

    make_ParallelTensorFunctor(per_batch, N)(std::thread::hardware_concurrency());
}
} // namespace ck_tile
include/ck_tile/host/reference/reference_gemm.hpp
View file @
ec959387
...
...
@@ -80,13 +80,14 @@ __global__ void naive_gemm_kernel(ADataType* A,
int
b_index
=
(
std
::
is_same_v
<
LayoutB
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
?
col
*
strideB
+
k
:
k
*
strideB
+
col
;
acc
+=
static_cast
<
AccDataType
>
(
A
[
a_index
])
*
static_cast
<
AccDataType
>
(
B
[
b_index
]);
acc
+=
ck_tile
::
type_convert
<
AccDataType
>
(
A
[
a_index
])
*
ck_tile
::
type_convert
<
AccDataType
>
(
B
[
b_index
]);
}
int
c_index
=
(
std
::
is_same_v
<
LayoutC
,
tensor_layout
::
gemm
::
RowMajor
>
)
?
row
*
strideC
+
col
:
col
*
strideC
+
row
;
C
[
c_index
]
=
acc
;
C
[
c_index
]
=
ck_tile
::
type_convert
<
CDataType
>
(
acc
)
;
}
}
...
...
include/ck_tile/host/reference/reference_moe_sorting.hpp
View file @
ec959387
...
...
@@ -14,12 +14,15 @@ namespace ck_tile {
template
<
typename
WeightType
,
typename
IndexType
=
index_t
>
CK_TILE_HOST
void
reference_moe_sorting
(
const
HostTensor
<
IndexType
>&
topk_ids
,
const
HostTensor
<
WeightType
>&
weights
,
const
HostTensor
<
IndexType
>&
local_expert_mask
,
HostTensor
<
IndexType
>&
p_sorted_token_ids
,
HostTensor
<
WeightType
>&
sorted_weight
,
HostTensor
<
IndexType
>&
sorted_expert_ids
,
index_t
&
unit_cnt
,
const
index_t
experts
,
const
index_t
unit_size
)
const
index_t
unit_size
,
bool
local_expert_masking
,
bool
skip_experts_with_zero_token
=
true
)
{
const
index_t
num_token
=
topk_ids
.
mDesc
.
get_lengths
()[
0
];
const
index_t
topk
=
topk_ids
.
mDesc
.
get_lengths
()[
1
];
...
...
@@ -33,8 +36,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
#endif
std
::
vector
<
std
::
vector
<
WeightType
>>
expert_token_weights
(
experts
,
std
::
vector
<
WeightType
>
(
unit_size
,
0
));
// count number of unit-size slices in this expert
std
::
vector
<
IndexType
>
expert_slices
(
experts
,
1
);
// count the tokens used in this expert
std
::
vector
<
IndexType
>
expert_slice_idxs
(
experts
,
0
);
// TODO: above 2 buffer seems duplicated
for
(
index_t
t
=
0
;
t
<
num_token
;
t
++
)
{
...
...
@@ -72,8 +78,23 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
IndexType
*
out_tokens
=
p_sorted_token_ids
.
data
();
WeightType
*
out_weights
=
sorted_weight
.
data
();
IndexType
*
out_expert_id
=
sorted_expert_ids
.
data
();
int
curr_expert_id
=
0
;
for
(
index_t
e
=
0
;
e
<
experts
;
e
++
)
{
if
(
local_expert_masking
)
{
if
(
local_expert_mask
(
e
)
==
0
)
continue
;
}
if
(
skip_experts_with_zero_token
)
{
if
(
expert_slice_idxs
[
e
]
==
0
)
{
curr_expert_id
++
;
continue
;
}
}
memcpy
(
out_tokens
,
expert_tokens
[
e
].
data
(),
sizeof
(
index_t
)
*
expert_slices
[
e
]
*
unit_size
);
out_tokens
+=
expert_slices
[
e
]
*
unit_size
;
memcpy
(
out_weights
,
...
...
@@ -83,10 +104,11 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
for
(
index_t
s
=
0
;
s
<
expert_slices
[
e
];
s
++
)
{
out_expert_id
[
s
]
=
e
;
out_expert_id
[
s
]
=
curr_expert_id
;
unit_cnt
++
;
}
out_expert_id
+=
expert_slices
[
e
];
curr_expert_id
++
;
}
unit_cnt
*=
unit_size
;
return
;
...
...
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
View file @
ec959387
...
...
@@ -10,3 +10,4 @@
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/batched_transpose.hpp
0 → 100644
View file @
ec959387
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/batched_transpose/kernel/batched_transpose_kernel.hpp
0 → 100644
View file @
ec959387
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {

// Host-side launch arguments for the batched transpose kernel.
struct BatchedTransposeHostArgs
{
    const void* p_input; // source, one (height x width) plane per batch
    void* p_output;      // destination, one (width x height) plane per batch
    index_t batch;
    index_t height;
    index_t width;
    // index_t dim_blocks;
    index_t dim_stride;  // element stride between consecutive batch planes
    index_t dim_block_h; // tile height used to size the launch grid
    index_t dim_block_w; // tile width used to size the launch grid
};

template <typename Pipeline_>
struct BatchedTransposeKernel
{
    using Pipeline = remove_cvref_t<Pipeline_>;
    using Problem  = remove_cvref_t<typename Pipeline::Problem>;
    using Type     = typename Problem::InputType;

    // Device-side kernel arguments (trimmed copy of the host args).
    struct BatchedTransposeKargs
    {
        const void* p_input;
        void* p_output;
        index_t batch;
        index_t height;
        index_t width;
        index_t dim_stride;
    };

    using Kargs = BatchedTransposeKargs;
    using Hargs = BatchedTransposeHostArgs;

    // Grid layout: x tiles the width, y tiles the height, z walks the batch.
    CK_TILE_HOST static constexpr auto GridSize(const Hargs& h)
    {
        size_t grid_size_x = (h.width + h.dim_block_w - 1) / h.dim_block_w;
        size_t grid_size_y = (h.height + h.dim_block_h - 1) / h.dim_block_h;
        size_t grid_size_z = h.batch;
        return dim3(grid_size_x, grid_size_y, grid_size_z);
    }

    CK_TILE_HOST static constexpr auto MakeKargs(const Hargs& h)
    {
        Kargs k;
        k.p_input    = h.p_input;
        k.p_output   = h.p_output;
        k.batch      = h.batch;
        k.height     = h.height;
        k.width      = h.width;
        k.dim_stride = h.dim_stride;
        return k;
    }

    CK_TILE_HOST_DEVICE static constexpr auto BlockSize() { return Problem::kBlockSize; }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
        static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;
        static constexpr bool kPadM                  = Problem::kPadM;
        static constexpr bool kPadN                  = Problem::kPadN;

        static constexpr ck_tile::index_t kMPerThread = Problem::kMPerThread;
        static constexpr ck_tile::index_t kNPerThread = Problem::kNPerThread;
        static_assert(kMPerThread == 1 && kNPerThread == 1);

        const auto iDim = blockIdx.z; // batch index

        // Input view: (M, N) = (height, width), padded up to whole block tiles.
        const auto x_m_n = [&]() {
            const auto x_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const Type*>(kargs.p_input) + iDim * kargs.dim_stride,
                make_tuple(kargs.height, kargs.width),
                make_tuple(kargs.width, 1),
                number<kNPerThread>{}, // TODO thread load value
                number<1>{});
            return pad_tensor_view(x_dram_naive,
                                   make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
                                   sequence<kPadM, kPadN>{});
        }();

        // BUGFIX: GridSize() tiles the width along blockIdx.x and the height along
        // blockIdx.y, so M (height) must come from blockIdx.y and N (width) from
        // blockIdx.x — the original code had them swapped. The original also
        // multiplied these already-scaled offsets by k{M,N}PerBlock a second time
        // when forming the window origins below, double-scaling every block's
        // origin; the origins now use iM / iN directly.
        const auto iM = __builtin_amdgcn_readfirstlane(blockIdx.y * kMPerBlock);
        const auto iN = __builtin_amdgcn_readfirstlane(blockIdx.x * kNPerBlock);

        // Output view: (N, M) = (width, height), i.e. the transposed plane.
        const auto y_n_m = [&]() {
            const auto y_dram_naive = make_naive_tensor_view<address_space_enum::global>(
                static_cast<Type*>(kargs.p_output) + iDim * kargs.dim_stride,
                make_tuple(kargs.width, kargs.height),
                make_tuple(kargs.height, 1),
                number<kMPerThread>{},
                number<1>{});
            return pad_tensor_view(y_dram_naive,
                                   make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
                                   sequence<kPadN, kPadM>{});
        }();

        auto x_block_window = make_tile_window(
            x_m_n,
            make_tuple(number<kMPerBlock>{}, number<kNPerBlock>{}),
            {static_cast<ck_tile::index_t>(iM), static_cast<ck_tile::index_t>(iN)});

        auto y_block_window = make_tile_window(
            y_n_m,
            make_tuple(number<kNPerBlock>{}, number<kMPerBlock>{}),
            {static_cast<ck_tile::index_t>(iN), static_cast<ck_tile::index_t>(iM)});

        Pipeline{}(x_block_window, y_block_window);
    }
};
} // namespace ck_tile
include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_pipeline.hpp
0 → 100644
View file @
ec959387
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {

// Block-level transpose pipeline: loads one input tile, copies it element-wise
// into a distributed tensor laid out with the (transposed) output distribution,
// then stores it through the output window.
template <typename Problem_, typename Policy_ = BatchedTransposePolicy>
struct BatchedTransposePipeline
{
    // TODO: this kernel only support warp per row
    using Problem   = remove_cvref_t<Problem_>;
    using Policy    = remove_cvref_t<Policy_>;
    using InputType = ck_tile::remove_cvref_t<typename Problem::InputType>;

    static constexpr ck_tile::index_t kMPerBlock = Problem::kMPerBlock;
    static constexpr ck_tile::index_t kNPerBlock = Problem::kNPerBlock;

    static constexpr index_t AlignmentM = Problem::AlignmentM;
    static constexpr index_t AlignmentN = Problem::AlignmentN;

    static constexpr bool kPadM = Problem::kPadM;
    static constexpr bool kPadN = Problem::kPadN;

    template <typename InputWindow, typename OutputWindow>
    CK_TILE_DEVICE auto operator()(const InputWindow& input_window, OutputWindow& out_window)
    {
        auto tile_in =
            make_tile_window(input_window, Policy::template MakeInputDistribution<Problem>());
        auto tile_out =
            make_tile_window(out_window, Policy::template MakeOutputDistribution<Problem>());

        auto x = load_tile(tile_in); // x->thread input_win->block

        auto y = make_static_distributed_tensor<InputType>(
            Policy::template MakeOutputDistribution<Problem>());

        constexpr auto span_2d_x = decltype(x)::get_distributed_spans();

        // NOTE(review): both x and y are addressed with the same reversed index
        // tuple (idx1, idx0); correctness relies on the output distribution being
        // the swapped-dimension mirror of the input one — confirm against Policy.
        sweep_tile_span(span_2d_x[number<0>{}], [&](auto idx0) {
            sweep_tile_span(span_2d_x[number<1>{}], [&](auto idx1) {
                constexpr auto i_j_idx = make_tuple(idx1, idx0);
                y(i_j_idx)             = x(i_j_idx);
            });
        });

        store_tile(tile_out, y);
    }
};
} // namespace ck_tile
include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_policy.hpp
0 → 100644
View file @
ec959387
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
namespace ck_tile {

// Tile-distribution policy for the batched transpose pipeline. Each of the two
// tile dimensions is decomposed as warp-per-block x thread-per-warp x
// element-per-thread; the output distribution mirrors the input one with the
// M and N hierarchies swapped.
struct BatchedTransposePolicy
{
    // Distribution used to read the input tile: dim0 = M, dim1 = N.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeInputDistribution()
    {
        using S = Problem;
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<>,
                tuple<sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>,
                      sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>>,
                tuple<sequence<1, 2>, sequence<1, 2>>,
                tuple<sequence<0, 0>, sequence<1, 1>>,
                sequence<1, 2>,
                sequence<2, 2>>{});
    }

    // Distribution used to write the output tile: dim0 = N, dim1 = M (swapped).
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeOutputDistribution()
    {
        using S = Problem;
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<>,
                tuple<sequence<S::kNWarpPerBlock, S::kNThreadPerWarp, S::kNPerThread>,
                      sequence<S::kMWarpPerBlock, S::kMThreadPerWarp, S::kMPerThread>>,
                tuple<sequence<2, 1>, sequence<2, 1>>,
                tuple<sequence<0, 0>, sequence<1, 1>>,
                sequence<2, 1>,
                sequence<2, 2>>{});
    }
};
} // namespace ck_tile
include/ck_tile/ops/batched_transpose/pipeline/batched_transpose_problem.hpp
0 → 100644
View file @
ec959387
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
#define VectorLoadSize 16
namespace ck_tile {

// Compile-time problem description for the batched transpose kernel.
// BlockTile / WarpTile / ThreadTile are 2-element sequences giving the (M, N)
// extents at each level of the thread hierarchy; the derived counts below must
// divide evenly (no static_assert in the original — divisibility is assumed).
template <typename InputType_,
          typename BlockTile,  // Sequence<...
          typename WarpTile,   // Sequence<...
          typename ThreadTile, // Sequence<...
          bool kPadM_ = true,
          bool kPadN_ = true>
struct BatchedTransposeProblem
{
    using InputType = remove_cvref_t<InputType_>;

    // Per-thread tile extents.
    static constexpr index_t kMPerThread = ThreadTile::at(number<0>{});
    static constexpr index_t kNPerThread = ThreadTile::at(number<1>{});

    // Per-warp tile extents and the thread counts they imply.
    static constexpr index_t kMPerWarp       = WarpTile::at(number<0>{});
    static constexpr index_t kNPerWarp       = WarpTile::at(number<1>{});
    static constexpr index_t kMThreadPerWarp = kMPerWarp / kMPerThread;
    static constexpr index_t kNThreadPerWarp = kNPerWarp / kNPerThread;

    // Per-block tile extents and the warp counts they imply.
    static constexpr index_t kMPerBlock     = BlockTile::at(number<0>{});
    static constexpr index_t kNPerBlock     = BlockTile::at(number<1>{});
    static constexpr index_t kMWarpPerBlock = kMPerBlock / kMPerWarp;
    static constexpr index_t kNWarpPerBlock = kNPerBlock / kNPerWarp;

    // Total threads per block.
    static constexpr index_t kBlockSize =
        kMThreadPerWarp * kNThreadPerWarp * kMWarpPerBlock * kNWarpPerBlock;

    static constexpr bool kPadM = kPadM_;
    static constexpr bool kPadN = kPadN_;

    // TODO
    // NOTE(review): VectorLoadSize is a plain #define in this header (16 bytes);
    // consider an inline constexpr to avoid macro leakage into includers.
    static constexpr index_t AlignmentM = kPadM ? VectorLoadSize / sizeof(InputType) : 1;
    static constexpr index_t AlignmentN = kPadN ? VectorLoadSize / sizeof(InputType) : 1;
};
} // namespace ck_tile
include/ck_tile/ops/common.hpp
View file @
ec959387
...
...
@@ -5,3 +5,4 @@
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/common/utils.hpp
0 → 100644
View file @
ec959387
// SPDX-License-Identifier: MIT
// Copyright (c) 2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <iostream>
#include <string>
#include "ck_tile/core.hpp"
namespace ck_tile {

// Maps an element type to the short precision tag used in kernel name strings.
// Only the listed types are supported; using any other type is a compile error
// (the primary template is declared but never defined).
// clang-format off
template <typename T> struct typeToStr;
template <> struct typeToStr<float>  { static constexpr const char* name = "fp32"; };
template <> struct typeToStr<fp16_t> { static constexpr const char* name = "fp16"; };
template <> struct typeToStr<bf16_t> { static constexpr const char* name = "bf16"; };
template <> struct typeToStr<fp8_t>  { static constexpr const char* name = "fp8"; };
template <> struct typeToStr<bf8_t>  { static constexpr const char* name = "bf8"; };
template <> struct typeToStr<int8_t> { static constexpr const char* name = "int8"; };
// clang-format on

// Precision tag for a GEMM: "fp16" when A and B share a type, otherwise the
// two tags joined with an underscore (e.g. "fp16_bf8").
template <typename ADataType_, typename BDataType_>
std::string gemm_prec_str()
{
    std::string tag{typeToStr<ADataType_>::name};
    if constexpr(!std::is_same_v<ADataType_, BDataType_>)
    {
        tag += "_";
        tag += typeToStr<BDataType_>::name;
    }
    return tag;
}
} // namespace ck_tile
include/ck_tile/ops/elementwise.hpp
View file @
ec959387
...
...
@@ -6,3 +6,4 @@
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/epilogue.hpp
View file @
ec959387
...
...
@@ -8,3 +8,4 @@
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/epilogue/cshuffle_epilogue.hpp
View file @
ec959387
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#
define CK_TILE_MAX_RANK 5
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#
include "ck_tile/ops/common/tensor_layout.hpp"
namespace
ck_tile
{
// this epilogue aiming to store a matrix with different layout from the shared memory to the global
// memory.
template
<
typename
AccDataType_
,
typename
ODataType_
,
bool
kPadM_
,
bool
kPadN_
,
bool
kTilePermute_
,
index_t
kRank_
,
index_t
kPerm0
,
index_t
kPerm1
,
index_t
TileSize0
,
index_t
TileSize1
,
index_t
kPerm2
=
0
,
index_t
kPerm3
=
0
,
index_t
kPerm4
=
0
,
index_t
TileSize2
=
0
,
index_t
TileSize3
=
0
,
index_t
TileSize4
=
0
>
typename
CLayout_
,
index_t
kBlockSize_
,
index_t
kM_
,
index_t
kN_
,
index_t
kMWave_
,
index_t
kNWave_
,
index_t
kMPerXdl_
,
index_t
kNPerXdl_
,
index_t
kKPerXdl_
,
bool
isCTransposed_
>
struct
CShuffleEpilogueProblem
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kTilePermute
=
kTilePermute_
;
static
constexpr
index_t
kRank
=
kRank_
;
static
constexpr
index_t
kPerm
[
CK_TILE_MAX_RANK
]
=
{
kPerm0
,
kPerm1
,
kPerm2
,
kPerm3
,
kPerm4
};
static
constexpr
index_t
tile_sizes
[
CK_TILE_MAX_RANK
]
=
{
TileSize0
,
TileSize1
,
TileSize2
,
TileSize3
,
TileSize4
};
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
CLayout
=
remove_cvref_t
<
CLayout_
>
;
static
constexpr
index_t
kBlockSize
=
kBlockSize_
;
static
constexpr
index_t
kMPerBlock
=
kM_
;
static
constexpr
index_t
kNPerBlock
=
kN_
;
static
constexpr
index_t
kMWave
=
kMWave_
;
static
constexpr
index_t
kNWave
=
kNWave_
;
static
constexpr
index_t
kMPerXdl
=
kMPerXdl_
;
static
constexpr
index_t
kNPerXdl
=
kNPerXdl_
;
static
constexpr
index_t
kKPerXdl
=
kKPerXdl_
;
static
constexpr
index_t
isCTransposed
=
isCTransposed_
;
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
CShuffleEpilogue
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
const
index_t
*
kPerm
=
Problem
::
kPerm
;
static
constexpr
bool
kTilePermute
=
Problem
::
kTilePermute
;
static
constexpr
index_t
kRank
=
Problem
::
kRank
;
const
index_t
*
tile_sizes
=
Problem
::
tile_sizes
;
// No additional shared memory needed
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
CK_TILE_HOST_DEVICE
static
constexpr
bool
IsOutputTransposed
()
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
CLayout
=
remove_cvref_t
<
typename
Problem
::
CLayout
>
;
static
constexpr
index_t
kBlockSize
=
Problem
::
kBlockSize
;
static
constexpr
index_t
kMPerBlock
=
Problem
::
kMPerBlock
;
static
constexpr
index_t
kNPerBlock
=
Problem
::
kNPerBlock
;
static
constexpr
index_t
kMWave
=
Problem
::
kMWave
;
static
constexpr
index_t
kNWave
=
Problem
::
kNWave
;
static
constexpr
index_t
kMPerXdl
=
Problem
::
kMPerXdl
;
static
constexpr
index_t
kNPerXdl
=
Problem
::
kNPerXdl
;
static
constexpr
index_t
kKPerXdl
=
Problem
::
kKPerXdl
;
static
constexpr
index_t
isCTransposed
=
Problem
::
isCTransposed
;
static
constexpr
index_t
kMPerIteration
=
kMPerXdl
*
kMWave
;
static
constexpr
index_t
kNPerIteration
=
kNPerXdl
*
kNWave
;
using
WG
=
WarpGemmMfmaDispatcher
<
ODataType
,
ODataType
,
AccDataType
,
kMPerXdl
,
kNPerXdl
,
kKPerXdl
,
isCTransposed
>
;
using
CWarpDstr
=
typename
WG
::
CWarpDstr
;
using
CWarpTensor
=
typename
WG
::
CWarpTensor
;
/**
* @brief Get the vector store size for C tensor.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
template
<
typename
ODataType
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetVectorSizeC
()
{
// TODO: At now CShuffle doesn't allow to vector store after permute.
// It should be fixed and this function should return true.
return
false
;
constexpr
index_t
MaxVectorStoreSize
=
16
;
return
MaxVectorStoreSize
/
sizeof
(
ODataType
);
}
template
<
typename
OAccTi
le
>
CK_TILE_DEVICE
void
permute_tile_data
(
OAccTile
&
o_acc_tile
)
template
<
typename
Prob
le
m
>
CK_TILE_
HOST_
DEVICE
static
constexpr
auto
MakeLdsBlockDescriptor
(
)
{
using
DataType
=
typename
OAccTile
::
DataType
;
// Get thread buffer
auto
&
thread_buf
=
o_acc_tile
.
get_thread_buffer
();
// Create a temporary buffer to hold the permuted data
thread_buffer
<
DataType
,
OAccTile
::
kThreadElementSpaceSize
>
permuted_thread_buf
;
// Get the lengths of each dimension
auto
thread_tensor_lengths
=
o_acc_tile
.
get_lengths
();
// Total number of elements
index_t
total_elements
=
OAccTile
::
kThreadElementSpaceSize
;
// Iterate over all elements
for
(
index_t
linear_idx
=
0
;
linear_idx
<
total_elements
;
++
linear_idx
)
// N is contiguous dimension
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
RowMajor
>
)
{
// Convert linear index to multi-dimensional indices
array
<
index_t
,
kRank
>
indices
;
index_t
remaining
=
linear_idx
;
static_for
<
0
,
kRank
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
rev_i
=
kRank
-
1
-
i
;
indices
(
rev_i
)
=
remaining
%
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
remaining
/=
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
});
// Apply the permutation
array
<
index_t
,
kRank
>
permuted_indices
;
static_for
<
0
,
kRank
,
1
>
{}(
[
&
](
auto
i
)
{
permuted_indices
(
i
)
=
indices
.
get
(
number
<
Problem
::
kPerm
[
i
]
>
{});
});
// Compute offsets
index_t
dst_offset
=
0
;
index_t
stride
=
1
;
static_for
<
0
,
kRank
,
1
>
{}([
&
](
auto
i
)
{
constexpr
auto
rev_i
=
kRank
-
1
-
i
;
dst_offset
+=
permuted_indices
[
rev_i
]
*
stride
;
stride
*=
thread_tensor_lengths
.
get
(
number
<
rev_i
>
{});
});
// Move the data
permuted_thread_buf
(
dst_offset
)
=
thread_buf
[
linear_idx
];
return
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kMWave
*
kMPerXdl
>
{},
number
<
kNWave
*
kNPerXdl
>
{}),
make_tuple
(
number
<
kNWave
*
kNPerXdl
>
{},
number
<
1
>
{}));
}
// Copy the permuted data back to the original thread buffer
for
(
index_t
i
=
0
;
i
<
total_elements
;
++
i
)
// M is contiguous dimension
else
if
constexpr
(
std
::
is_same_v
<
CLayout
,
tensor_layout
::
gemm
::
ColumnMajor
>
)
{
return
make_naive_tensor_descriptor
(
make_tuple
(
number
<
kMWave
*
kMPerXdl
>
{},
number
<
kNWave
*
kNPerXdl
>
{}),
make_tuple
(
number
<
1
>
{},
number
<
kMWave
*
kMPerXdl
>
{}));
}
else
{
thread_buf
.
set_as
(
i
,
permuted_thread_buf
.
get
(
i
)
);
static_assert
(
false
,
"Unsupported CLayout!"
);
}
}
template
<
typename
ODramWindowTmp
,
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
kMWave
*
kNWave
*
kMPerXdl
*
kNPerXdl
*
sizeof
(
ODataType
);
}
template
<
typename
ODramWindow
,
typename
OAccTile
,
memory_operation_enum
out_memory_data_op
=
memory_operation_enum
::
set
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
OAccTile
&
o_acc_tile
)
CK_TILE_DEVICE
auto
operator
()(
ODramWindow
&
out_dram_window
,
const
OAccTile
&
o_acc_tile
,
void
*
p_smem
)
{
const
auto
&
current_window_origin
=
o_dram_window_tmp
.
get_window_origin
();
// Compute the tile coordinates by dividing the window origin by the tile sizes
index_t
tile_coords
[
CK_TILE_MAX_RANK
]
=
{
0
};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
tile_coords
[
i
]
=
current_window_origin
[
i
]
/
tile_sizes
[
i
];
// printf("The tile_coord is: %d", tile_coords[i]);
}
// Apply the permutation to the tile coordinates
index_t
permuted_tile_coords
[
CK_TILE_MAX_RANK
];
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
permuted_tile_coords
[
i
]
=
tile_coords
[
kPerm
[
i
]];
// printf("The new permuted_tile_coords is: %d", permuted_tile_coords[i]);
}
// Compute the permuted window origin
index_t
permuted_window_origin
[
CK_TILE_MAX_RANK
]
=
{
0
};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
permuted_window_origin
[
i
]
=
permuted_tile_coords
[
i
]
*
tile_sizes
[
i
];
// printf("The new permuted_window_origin is: %d", permuted_window_origin[i]);
}
typename
ODramWindowTmp
::
BottomTensorIndex
step
=
{};
for
(
index_t
i
=
0
;
i
<
kRank
;
++
i
)
{
step
[
i
]
=
permuted_window_origin
[
i
]
-
current_window_origin
[
i
];
}
const
index_t
iMWarp
=
get_warp_id
()
/
kNWave
;
const
index_t
iNWarp
=
get_warp_id
()
-
iMWarp
*
kNWave
;
constexpr
auto
lds_block_desc
=
MakeLdsBlockDescriptor
<
Problem
>
();
auto
o_lds_block
=
make_tensor_view
<
address_space_enum
::
lds
>
(
static_cast
<
ODataType
*>
(
p_smem
),
lds_block_desc
);
auto
in_lds_window
=
make_tile_window
(
o_lds_block
,
make_tuple
(
number
<
kMPerXdl
>
{},
number
<
kNPerXdl
>
{}),
{
number
<
kMPerXdl
>
{}
*
iMWarp
,
number
<
kNPerXdl
>
{}
*
iNWarp
});
auto
out_lds_window
=
make_tile_window
(
o_lds_block
,
make_tuple
(
number
<
kMWave
*
kMPerXdl
>
{},
number
<
kNWave
*
kNPerXdl
>
{}),
{
0
,
0
});
using
SFC
=
space_filling_curve
<
sequence
<
kMPerBlock
,
kNPerBlock
>
,
sequence
<
0
,
1
>
,
sequence
<
kMPerXdl
*
kMWave
,
kNPerXdl
*
kNWave
>>
;
constexpr
index_t
num_access
=
SFC
::
get_num_of_access
();
using
TileEncodingPattern
=
TileDistributionEncodingPattern2D
<
kBlockSize
,
kMPerIteration
,
kNPerIteration
,
GetVectorSizeC
<
ODataType
>
(),
tile_distribution_pattern
::
thread_raked
>
;
constexpr
auto
dram_tile_distribution
=
TileEncodingPattern
::
Make2DStaticTileDistribution
();
constexpr
auto
c_warp_y_lengths
=
to_sequence
(
CWarpDstr
{}.
get_ys_to_d_descriptor
().
get_lengths
());
constexpr
auto
c_warp_y_index_zeros
=
uniform_sequence_gen_t
<
CWarpDstr
::
NDimY
,
0
>
{};
CWarpTensor
c_warp_in_tensor
;
static_for
<
0
,
num_access
,
1
>
{}([
&
](
auto
iAccess
)
{
constexpr
auto
idx_y_start
=
SFC
::
get_index
(
iAccess
);
constexpr
auto
mIter
=
number
<
idx_y_start
.
at
(
number
<
0
>
{})
/
(
kMPerXdl
*
kMWave
)
>
{};
constexpr
auto
nIter
=
number
<
idx_y_start
.
at
(
number
<
1
>
{})
/
(
kNPerXdl
*
kNWave
)
>
{};
c_warp_in_tensor
.
get_thread_buffer
()
=
o_acc_tile
.
get_y_sliced_thread_data
(
merge_sequences
(
sequence
<
mIter
,
nIter
>
{},
c_warp_y_index_zeros
),
merge_sequences
(
sequence
<
1
,
1
>
{},
c_warp_y_lengths
));
const
auto
c_warp_in_tensor_casted
=
cast_tile
<
ODataType
>
(
c_warp_in_tensor
);
block_sync_lds
();
store_tile
(
in_lds_window
,
c_warp_in_tensor_casted
);
block_sync_lds
();
const
auto
c_out_tensor
=
load_tile
(
make_tile_window
(
out_lds_window
,
dram_tile_distribution
));
// Move the window
move_tile_window
(
o_dram_window_tmp
,
step
);
// Permute the data within the tile if necessary
if
constexpr
(
kTilePermute
)
{
permute_tile_data
(
o_acc_tile
);
}
// Store the tile data to the permuted location
if
constexpr
(
kPadM
||
kPadN
)
{
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
{
store_tile
_raw
(
o_dram_window
_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
)
);
store_tile
(
o
ut
_dram_window
,
c_out_tensor
);
}
else
{
update_tile
_raw
(
o_dram_window
_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
)
);
update_tile
(
o
ut
_dram_window
,
c_out_tensor
);
}
buffer_store_fence
();
}
else
{
if
constexpr
(
out_memory_data_op
==
memory_operation_enum
::
set
)
if
constexpr
(
iAccess
!=
num_access
-
1
)
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
constexpr
auto
step
=
SFC
::
get_forward_step
(
iAccess
);
move_tile_window
(
out_dram_window
,
{
step
.
at
(
number
<
0
>
{}),
step
.
at
(
number
<
1
>
{})});
}
else
{
update_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
}
}
});
}
};
}
// namespace ck_tile
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
View file @
ec959387
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
5
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace
ck_tile
{
...
...
@@ -23,6 +25,26 @@ struct Default2DEpilogueProblem
static
constexpr
bool
UseRawStore
=
UseRawStore_
;
};
/// @brief Problem descriptor for the GEMM-specialized 2D epilogue.
///
/// Extends Default2DEpilogueProblem with the C-tensor layout and the
/// per-XDL (MFMA instruction) tile extents that the epilogue needs in
/// order to dispatch a matching warp-gemm configuration.
///
/// @tparam AccDataType_   Accumulator element type.
/// @tparam ODataType_     Output element type.
/// @tparam CLayout_       Layout of the C/output tensor (e.g. Row/ColumnMajor).
/// @tparam kPadM_         Whether the M dimension is padded.
/// @tparam kPadN_         Whether the N dimension is padded.
/// @tparam kMPerXdl_      M extent covered by one XDL instruction.
/// @tparam kNPerXdl_      N extent covered by one XDL instruction.
/// @tparam kKPerXdl_      K extent covered by one XDL instruction.
/// @tparam isCTransposed_ Whether the C tile is transposed.
/// @tparam UseRawStore_   Use raw stores when possible (default: true).
template <typename AccDataType_,
          typename ODataType_,
          typename CLayout_,
          bool kPadM_,
          bool kPadN_,
          index_t kMPerXdl_,
          index_t kNPerXdl_,
          index_t kKPerXdl_,
          bool isCTransposed_,
          bool UseRawStore_ = true>
struct DefaultGemm2DEpilogueProblem
    : public Default2DEpilogueProblem<AccDataType_, ODataType_, kPadM_, kPadN_, UseRawStore_>
{
    using CLayout = remove_cvref_t<CLayout_>;

    static constexpr index_t kMPerXdl = kMPerXdl_;
    static constexpr index_t kNPerXdl = kNPerXdl_;
    static constexpr index_t kKPerXdl = kKPerXdl_;

    // Fix: this flag comes from (and is consumed as) a bool template
    // parameter; it was previously declared as index_t, which obscured
    // its meaning. bool is backward-compatible for all readers.
    static constexpr bool isCTransposed = isCTransposed_;
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
Default2DEpilogue
{
...
...
@@ -35,14 +57,13 @@ struct Default2DEpilogue
// Amount of LDS (shared memory) this epilogue requires, in bytes.
// The default 2D epilogue writes straight to its output window, so it
// needs no scratch LDS at all.
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return 0; }
// Whether this epilogue stores the output tile in transposed order.
// The default 2D epilogue always keeps the original (non-transposed)
// layout; transposing variants override this query.
CK_TILE_HOST_DEVICE static constexpr bool IsOutputTransposed() { return false; }
// TODO: this function assumes the store-out vector size is the same as the OAccTile last dimension size
// how do we fix this?
template
<
typename
ODramWindowTmp
,
typename
OAccTile
,
memory_operation_enum
out_memory_data_op
=
memory_operation_enum
::
set
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
const
OAccTile
&
o_acc_tile
)
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
const
OAccTile
&
o_acc_tile
,
void
*
=
nullptr
)
{
// TODO: this is ugly
...
...
@@ -71,4 +92,76 @@ struct Default2DEpilogue
}
}
};
/// @brief GEMM-specialized 2D epilogue.
///
/// Inherits the plain store/update behaviour of Default2DEpilogue and adds
/// GetVectorSizeC(), which derives the widest safe per-thread store vector
/// for the C tensor from the warp-gemm thread distribution.
template <typename Problem_, typename Policy_ = void>
struct DefaultGemm2DEpilogue : public Default2DEpilogue<Problem_, Policy_>
{
    using Problem     = remove_cvref_t<Problem_>;
    using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
    using ODataType   = remove_cvref_t<typename Problem::ODataType>;
    using CLayout     = remove_cvref_t<typename Problem::CLayout>;

    static constexpr index_t kMPerXdl = Problem::kMPerXdl;
    static constexpr index_t kNPerXdl = Problem::kNPerXdl;
    static constexpr index_t kKPerXdl = Problem::kKPerXdl;

    // Fix: this is a true/false flag (fed to WarpGemmMfmaDispatcher's bool
    // parameter); it was previously declared as index_t.
    static constexpr bool isCTransposed = Problem::isCTransposed;

    using WG = WarpGemmMfmaDispatcher<ODataType,
                                      ODataType,
                                      AccDataType,
                                      kMPerXdl,
                                      kNPerXdl,
                                      kKPerXdl,
                                      isCTransposed>;

    using CWarpDstr = typename WG::CWarpDstr;

    // Number of consecutive C elements a single lane owns along the fastest
    // varying Y dimension of the warp C distribution. Shared by both layout
    // branches of GetVectorSizeC() below (the two arms were previously
    // duplicated verbatim).
    CK_TILE_HOST_DEVICE static constexpr auto GetContiguousCPerLane()
    {
        constexpr index_t NDimY = CWarpDstr::NDimY;
        constexpr auto c_warp_y_lengths =
            CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
        // Sanity check: the distribution's innermost length must match the
        // warp-gemm implementation's per-lane M1 count.
        static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
                      c_warp_y_lengths.get(number<NDimY - 1>{}));
        return c_warp_y_lengths.get(number<NDimY - 1>{});
    }

    /// @brief Vector size (elements) for storing the C tensor.
    /// @return Widest contiguous per-thread run in the memory-contiguous
    ///         dimension of CLayout, given the warp-gemm thread mapping.
    CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
    {
        // N is contiguous dimension
        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
            if constexpr(isCTransposed)
            {
                // Each thread has multiple consecutive elements in the N
                // dimension; consecutive threads' elements have stride.
                return GetContiguousCPerLane();
            }
            else
            {
                // Each thread has just a single item in the N dimension.
                return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
            }
        }
        // M is contiguous dimension
        else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
        {
            if constexpr(isCTransposed)
            {
                // Each thread has just a single item in the M dimension.
                return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
            }
            else
            {
                // Each thread has multiple consecutive elements in the M
                // dimension; consecutive threads' elements have stride.
                return GetContiguousCPerLane();
            }
        }
        else
        {
            // NOTE(review): static_assert(false) in a discarded constexpr
            // branch is only well-formed from C++23 (P2593) / recent
            // compilers — confirm the minimum supported toolchain accepts it.
            static_assert(false, "Unsupported CLayout!");
        }
    }
};
}
// namespace ck_tile
include/ck_tile/ops/flatmm.hpp
View file @
ec959387
...
...
@@ -9,3 +9,4 @@
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16.inc
View file @
ec959387
...
...
@@ -824,4 +824,4 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#undef CK_TILE_FLATMM_UK_MFMA
// clang-format on
// clang-format on
include/ck_tile/ops/flatmm/block/uk/flatmm_sn_uk_gfx9_32x128x512_1x4x1_16x16x16_itl.inc
View file @
ec959387
...
...
@@ -722,4 +722,4 @@
#undef _UK_PK_CVT_
#undef _UK_ATOMIC_ADD_
#undef CK_TILE_FLATMM_UK_MFMA
// clang-format on
// clang-format on
include/ck_tile/ops/flatmm/block/uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc
View file @
ec959387
...
...
@@ -771,4 +771,4 @@
#undef _UK_MFMA_
#undef CK_TILE_FLATMM_UK_2B
#undef CK_TILE_FLATMM_UK_MFMA
// clang-format on
// clang-format on
include/ck_tile/ops/fmha.hpp
View file @
ec959387
...
...
@@ -44,3 +44,4 @@
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/utils.hpp"
Prev
1
…
9
10
11
12
13
14
15
16
17
…
20
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment