Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
0475a327
Commit
0475a327
authored
Nov 04, 2024
by
dummycoderfe
Browse files
Merge branch 'ck_tile/layernorm2d_fwd_optimize' into ck_tile/ln_add_cache_clear
parents
c9b961ab
27ff3dec
Changes
267
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
639 additions
and
29 deletions
+639
-29
include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp
...de/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp
+16
-0
include/ck_tile/ops/topk.hpp
include/ck_tile/ops/topk.hpp
+9
-0
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp
+113
-0
include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp
...e/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp
+22
-0
include/ck_tile/ops/topk_softmax.hpp
include/ck_tile/ops/topk_softmax.hpp
+11
-0
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
...e/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
+166
-0
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
...k_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
+123
-0
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp
...opk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp
+63
-0
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp
...pk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp
+46
-0
include/ck_tile/ops/welford.hpp
include/ck_tile/ops/welford.hpp
+1
-0
include/ck_tile/ops/welford/block/block_welford.hpp
include/ck_tile/ops/welford/block/block_welford.hpp
+8
-7
include/ck_tile/ops/welford/thread/thread_welford.hpp
include/ck_tile/ops/welford/thread/thread_welford.hpp
+7
-6
include/ck_tile/remod.py
include/ck_tile/remod.py
+3
-2
library/src/tensor_operation_instance/gpu/CMakeLists.txt
library/src/tensor_operation_instance/gpu/CMakeLists.txt
+27
-3
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp
...device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp
+6
-4
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp
...f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp
+3
-2
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
...f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
+6
-4
profiler/src/profile_gemm_universal.cpp
profiler/src/profile_gemm_universal.cpp
+7
-1
test/CMakeLists.txt
test/CMakeLists.txt
+1
-0
test/ck_tile/CMakeLists.txt
test/ck_tile/CMakeLists.txt
+1
-0
No files found.
include/ck_tile/ops/softmax/block/block_softmax_2d_problem.hpp
0 → 100644
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
template
<
typename
DataType_
>
struct
BlockSoftmax2DProblem
{
using
DataType
=
remove_cvref_t
<
DataType_
>
;
};
}
// namespace ck_tile
include/ck_tile/ops/topk.hpp
0 → 100644
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk/block/block_topk_stream_2d.hpp"
#include "ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/topk/block/block_topk_stream_2d.hpp
0 → 100644
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
/*
simple 2d topk implementation, along row (dim=1)
requirement:
1). each row is within a warp
*/
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
BlockTopkStream2D
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
DataType
=
typename
Problem
::
DataType
;
using
IndexType
=
typename
Problem
::
IndexType
;
// TODO: if DataType is subdword, need pack into single dword to use argmax
struct
ArgmaxPacket
{
DataType
arg
;
index_t
value
;
};
template
<
typename
DistributedTensor
,
typename
OutWindow
,
typename
IdxWindow
,
index_t
dim
=
1
>
CK_TILE_DEVICE
void
operator
()(
const
DistributedTensor
&
x
,
const
OutWindow
&
out_window
,
const
IdxWindow
&
idx_window
,
index_t
k
,
number
<
dim
>
=
{})
{
OutWindow
out_window_tmp
=
out_window
;
IdxWindow
idx_window_tmp
=
idx_window
;
static_assert
(
std
::
is_same_v
<
typename
DistributedTensor
::
DataType
,
typename
OutWindow
::
DataType
>
&&
std
::
is_same_v
<
typename
DistributedTensor
::
DataType
,
DataType
>
);
static_assert
(
std
::
is_same_v
<
typename
IdxWindow
::
DataType
,
IndexType
>
);
DistributedTensor
x_tmp
=
x
;
constexpr
auto
dst_dist
=
typename
IdxWindow
::
TileDstr
{};
// argmax for topk
const
auto
f_argmax
=
[](
ArgmaxPacket
e0
,
ArgmaxPacket
e1
)
{
return
e0
.
arg
>
e1
.
arg
?
e0
:
e1
;
};
for
(
index_t
i_k
=
0
;
i_k
<
k
;
i_k
++
)
{
constexpr
auto
span_2d
=
DistributedTensor
::
get_distributed_spans
();
auto
packet
=
[
&
]()
{
auto
tmp
=
make_static_distributed_tensor
<
ArgmaxPacket
>
(
x
.
get_tile_distribution
());
sweep_tile_span
(
span_2d
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
span_2d
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
tmp
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
ArgmaxPacket
t
;
t
.
arg
=
x_tmp
(
i_j_idx
);
// !!! we reference x here
t
.
value
=
tile_idx
.
at
(
number
<
1
>
{});
tmp
(
i_j_idx
)
=
t
;
});
});
return
tmp
;
}();
auto
argmax_init
=
ArgmaxPacket
{
-
numeric
<
DataType
>::
infinity
(),
0
};
auto
r
=
block_tile_reduce
<
ArgmaxPacket
>
(
packet
,
sequence
<
1
>
{},
f_argmax
,
argmax_init
);
block_tile_reduce_xor_sync
(
r
,
f_argmax
);
auto
o
=
make_static_distributed_tensor
<
DataType
>
(
dst_dist
);
auto
i
=
make_static_distributed_tensor
<
IndexType
>
(
dst_dist
);
sweep_tile_span
(
span_2d
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
span_2d
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
ArgmaxPacket
tmp
=
r
(
i_j_idx
);
o
(
i_j_idx
)
=
tmp
.
arg
;
i
(
i_j_idx
)
=
tmp
.
value
;
});
});
// update value
sweep_tile_span
(
span_2d
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile_span
(
span_2d
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
const
auto
tile_idx
=
get_x_indices_from_distributed_indices
(
x
.
get_tile_distribution
(),
make_tuple
(
idx0
,
idx1
));
auto
col_id
=
tile_idx
.
at
(
number
<
1
>
{});
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
x_tmp
(
i_j_idx
)
=
(
col_id
==
r
(
i_j_idx
).
value
)
?
-
numeric
<
DataType
>::
infinity
()
:
x_tmp
(
i_j_idx
);
});
});
if
(
threadIdx
.
x
%
Problem
::
ColLanes
==
0
)
{
store_tile
(
out_window_tmp
,
o
);
store_tile
(
idx_window_tmp
,
i
);
}
move_tile_window
(
out_window_tmp
,
{
number
<
0
>
{},
number
<
1
>
{}});
move_tile_window
(
idx_window_tmp
,
{
number
<
0
>
{},
number
<
1
>
{}});
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/topk/block/block_topk_stream_2d_problem.hpp
0 → 100644
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
/*
simple 2d topk implementation, along row (dim=1)
requirement:
1). each row is within a warp
*/
template
<
typename
DataType_
,
typename
IndexType_
,
index_t
ColLanes_
>
struct
BlockTopkStream2DProblem
{
using
DataType
=
remove_cvref_t
<
DataType_
>
;
using
IndexType
=
remove_cvref_t
<
IndexType_
>
;
static
constexpr
index_t
ColLanes
=
ColLanes_
;
};
}
// namespace ck_tile
include/ck_tile/ops/topk_softmax.hpp
0 → 100644
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/topk_softmax/kernel/topk_softmax_kernel.hpp
0 → 100644
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include "ck_tile/ops/elementwise.hpp"
#include "ck_tile/host/hip_check_error.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
struct
TopkSoftmaxHostArgs
{
const
void
*
p_input
;
void
*
p_output
;
void
*
p_indices
;
index_t
num_rows
;
index_t
num_experts
;
index_t
topk
;
index_t
stride_input
;
// row stride for input, at least experts
index_t
stride_output
;
// row stride for output/indices, at least tpok
};
template
<
typename
Pipeline_
>
struct
TopkSoftmaxKernel
{
using
Pipeline
=
remove_cvref_t
<
Pipeline_
>
;
using
Problem
=
remove_cvref_t
<
typename
Pipeline
::
Problem
>
;
using
InputType
=
typename
Problem
::
InputType
;
using
WeightType
=
typename
Problem
::
WeightType
;
using
IndexType
=
typename
Problem
::
IndexType
;
struct
TopkSoftmaxKargs
{
const
void
*
p_input
;
void
*
p_output
;
void
*
p_indices
;
index_t
num_rows
;
index_t
num_experts
;
index_t
topk
;
index_t
stride_input
;
// row stride for input, at least experts
index_t
stride_output
;
// row stride for output/indices, at least tpok
};
using
Kargs
=
TopkSoftmaxKargs
;
using
Hargs
=
TopkSoftmaxHostArgs
;
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
h
)
{
if
constexpr
(
Problem
::
LaunchType
>
0
)
{
int
num_cu
=
[
&
]()
{
hipDeviceProp_t
dev_prop
;
hipDevice_t
dev
;
HIP_CHECK_ERROR
(
hipGetDevice
(
&
dev
));
HIP_CHECK_ERROR
(
hipGetDeviceProperties
(
&
dev_prop
,
dev
));
return
dev_prop
.
multiProcessorCount
;
}();
return
dim3
(
num_cu
*
Problem
::
LaunchType
);
}
else
{
const
int
num_warps
=
(
h
.
num_rows
+
Problem
::
RowsPerWarp
-
1
)
/
Problem
::
RowsPerWarp
;
const
int
num_blocks
=
(
num_warps
+
Problem
::
WarpsPerBlock
-
1
)
/
Problem
::
WarpsPerBlock
;
return
dim3
(
num_blocks
);
}
}
CK_TILE_HOST
static
constexpr
auto
MakeKargs
(
const
Hargs
&
h
)
{
Kargs
k
;
k
.
p_input
=
h
.
p_input
;
k
.
p_output
=
h
.
p_output
;
k
.
p_indices
=
h
.
p_indices
;
k
.
num_rows
=
h
.
num_rows
;
k
.
num_experts
=
h
.
num_experts
;
k
.
topk
=
h
.
topk
;
k
.
stride_input
=
h
.
stride_input
;
k
.
stride_output
=
h
.
stride_output
;
return
k
;
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockSize
;
}
CK_TILE_DEVICE
void
operator
()(
Kargs
kargs
)
const
{
index_t
block_row_id
=
static_cast
<
index_t
>
(
blockIdx
.
x
*
Problem
::
RowsPerBlock
);
if
(
block_row_id
>
kargs
.
num_rows
)
return
;
index_t
block_os_inp
=
__builtin_amdgcn_readfirstlane
(
block_row_id
*
kargs
.
stride_input
);
index_t
block_os_out
=
__builtin_amdgcn_readfirstlane
(
block_row_id
*
kargs
.
stride_output
);
index_t
num_rows_rem
=
__builtin_amdgcn_readfirstlane
(
kargs
.
num_rows
-
block_row_id
);
const
auto
input_window
=
[
&
]()
{
const
InputType
*
p_input
=
reinterpret_cast
<
const
InputType
*>
(
kargs
.
p_input
)
+
block_os_inp
;
auto
tmp
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_input
,
make_tuple
(
num_rows_rem
,
kargs
.
num_experts
),
make_tuple
(
kargs
.
stride_input
,
1
),
number
<
Problem
::
VectorSize
>
{},
number
<
1
>
{});
auto
view
=
pad_tensor_view
(
tmp
,
make_tuple
(
number
<
Problem
::
RowsPerBlock
>
{},
number
<
Problem
::
Experts
>
{}),
sequence
<
0
,
1
>
{});
// out-most dim no need pad(leverage oob)
return
make_tile_window
(
view
,
make_tuple
(
number
<
Problem
::
RowsPerBlock
>
{},
number
<
Problem
::
Experts
>
{}),
{
0
,
0
});
}();
auto
output_window
=
[
&
]()
{
WeightType
*
p_output
=
reinterpret_cast
<
WeightType
*>
(
kargs
.
p_output
)
+
block_os_out
;
auto
tmp
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_output
,
make_tuple
(
num_rows_rem
,
kargs
.
topk
),
make_tuple
(
kargs
.
stride_output
,
1
),
number
<
Problem
::
VectorSize
>
{},
number
<
1
>
{});
auto
view
=
pad_tensor_view
(
tmp
,
make_tuple
(
number
<
Problem
::
RowsPerBlock
>
{},
number
<
1
>
{}),
sequence
<
0
,
0
>
{});
// 1. out-most dim no need pad(leverage oob)
// 2. we loop over topk 1-1, no need padding
return
make_tile_window
(
view
,
make_tuple
(
number
<
Problem
::
RowsPerBlock
>
{},
number
<
1
>
{}),
{
0
,
0
});
}();
auto
indices_window
=
[
&
]()
{
IndexType
*
p_indices
=
reinterpret_cast
<
IndexType
*>
(
kargs
.
p_indices
)
+
block_os_out
;
auto
tmp
=
make_naive_tensor_view
<
address_space_enum
::
global
>
(
p_indices
,
make_tuple
(
num_rows_rem
,
kargs
.
topk
),
make_tuple
(
kargs
.
stride_output
,
1
),
number
<
Problem
::
VectorSize
>
{},
number
<
1
>
{});
auto
view
=
pad_tensor_view
(
tmp
,
make_tuple
(
number
<
Problem
::
RowsPerBlock
>
{},
number
<
1
>
{}),
sequence
<
0
,
0
>
{});
// 1. out-most dim no need pad(leverage oob)
// 2. we loop over topk 1-1, no need padding
return
make_tile_window
(
view
,
make_tuple
(
number
<
Problem
::
RowsPerBlock
>
{},
number
<
1
>
{}),
{
0
,
0
});
}();
Pipeline
{}(
input_window
,
output_window
,
indices_window
,
kargs
.
num_rows
,
kargs
.
num_experts
,
kargs
.
topk
,
block_row_id
);
}
};
}
// namespace ck_tile
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_pipeline.hpp
0 → 100644
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp"
#include <string>
#include <type_traits>
#ifndef TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
#define TOPK_SOFTMAX_USE_RAW_TILE_WINDOW 0
#endif
namespace
ck_tile
{
template
<
typename
Problem_
,
typename
Policy_
=
TopkSoftmaxWarpPerRowPolicy
>
struct
TopkSoftmaxWarpPerRowPipeline
{
// TODO: this kernel only support warp per row
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
Policy
=
remove_cvref_t
<
Policy_
>
;
using
WeightType
=
typename
Problem
::
WeightType
;
template
<
typename
InputWindow
,
typename
OutputWindow
,
typename
IndexWindow
>
CK_TILE_DEVICE
auto
operator
()(
const
InputWindow
&
input_window
,
OutputWindow
&
out_window
,
IndexWindow
&
idx_window
,
index_t
rows
,
index_t
experts
,
index_t
k
,
index_t
block_row_id
)
{
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
auto
inp_win
=
make_tile_window_linear_raw
(
input_window
,
Policy
::
template
MakeInputDistribution
<
Problem
>(),
sequence
<
0
,
1
>
{});
#else
auto
inp_win
=
make_tile_window_linear
(
input_window
,
Policy
::
template
MakeInputDistribution
<
Problem
>(),
sequence
<
0
,
1
>
{});
#endif
auto
out_win
=
make_tile_window_linear
(
out_window
.
get_bottom_tensor_view
(),
out_window
.
get_window_lengths
(),
out_window
.
get_window_origin
(),
Policy
::
template
MakeOutputDistribution
<
Problem
>());
auto
idx_win
=
make_tile_window_linear
(
idx_window
.
get_bottom_tensor_view
(),
idx_window
.
get_window_lengths
(),
idx_window
.
get_window_origin
(),
Policy
::
template
MakeOutputDistribution
<
Problem
>());
auto
softmax
=
Policy
::
template
GetSoftmax
<
Problem
>();
auto
topk
=
Policy
::
template
GetTopk
<
Problem
>();
const
index_t
grid_rows_per_loop
=
gridDim
.
x
*
Problem
::
RowsPerBlock
;
while
(
1
)
{
#if TOPK_SOFTMAX_USE_RAW_TILE_WINDOW
__builtin_amdgcn_sched_barrier
(
0
);
auto
x
=
load_tile_raw
(
inp_win
,
number
<-
1
>
{},
bool_constant
<
true
>
{},
bool_constant
<
true
>
{});
buffer_load_fence
(
number
<
0
>
{});
__builtin_amdgcn_sched_barrier
(
0
);
#else
auto
x
=
load_tile
(
inp_win
);
#endif
// cast and pad input data
auto
w
=
[
&
]()
{
#if 0
auto w_ = cast_tile<WeightType>(x);
constexpr auto span_2d = decltype(w_)::get_distributed_spans();
sweep_tile_span(span_2d[number<0>{}], [&](auto idx0) {
sweep_tile_span(span_2d[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
const auto x_indices = get_x_indices_from_distributed_indices(
w_.get_tile_distribution(), i_j_idx);
const auto current_expert = x_indices.at(number<1>{});
// set to -INF if OOB so that later softmax can work properly
w_(i_j_idx) = current_expert >= experts ? -numeric<WeightType>::infinity()
: w_(i_j_idx);
});
});
return w_;
#else
auto
w_
=
make_static_distributed_tensor
<
WeightType
>
(
x
.
get_tile_distribution
());
auto
w_f
=
[
&
](
auto
idx
)
{
w_
(
idx
)
=
type_convert
<
WeightType
>
(
x
(
idx
));
const
auto
x_indices
=
get_x_indices_from_distributed_indices
(
w_
.
get_tile_distribution
(),
idx
);
const
auto
current_expert
=
x_indices
.
at
(
number
<
1
>
{});
w_
(
idx
)
=
current_expert
>=
experts
?
-
numeric
<
WeightType
>::
infinity
()
:
w_
(
idx
);
};
tile_sweeper
ts
{
w_
,
w_f
};
ts
();
return
w_
;
#endif
}();
// softmax
auto
y
=
softmax
(
w
);
topk
(
y
,
out_win
,
idx_win
,
k
);
// check exit
if
constexpr
(
Problem
::
LaunchType
==
0
)
{
break
;
}
else
{
block_row_id
+=
grid_rows_per_loop
;
if
(
block_row_id
>=
rows
)
break
;
}
move_tile_window
(
inp_win
,
{
grid_rows_per_loop
,
number
<
0
>
{}});
move_tile_window
(
out_win
,
{
grid_rows_per_loop
,
number
<
0
>
{}});
move_tile_window
(
idx_win
,
{
grid_rows_per_loop
,
number
<
0
>
{}});
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_policy.hpp
0 → 100644
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/softmax.hpp"
#include "ck_tile/ops/topk.hpp"
namespace
ck_tile
{
struct
TopkSoftmaxWarpPerRowPolicy
{
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeInputDistribution
()
{
// TODO: Y dim must have one dim that is not reduced
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
1
>
,
tuple
<
sequence
<
Problem
::
IssuesPerCol
,
Problem
::
WarpsPerBlock
,
Problem
::
RowsPerWarpPerColIssue
>
,
sequence
<
Problem
::
IssuesPerRow
,
Problem
::
LanesPerRow
,
Problem
::
VectorSize
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
1
,
2
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
1
>>
,
sequence
<
1
,
2
,
2
>
,
sequence
<
0
,
0
,
2
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
MakeOutputDistribution
()
{
return
make_static_tile_distribution
(
tile_distribution_encoding
<
sequence
<
Problem
::
LanesPerRow
>
,
// repeat this one
tuple
<
sequence
<
Problem
::
IssuesPerCol
,
Problem
::
WarpsPerBlock
,
Problem
::
RowsPerWarpPerColIssue
>
,
sequence
<
1
>>
,
// each row write out single element
tuple
<
sequence
<
1
>
,
sequence
<
1
,
0
>>
,
tuple
<
sequence
<
1
>
,
sequence
<
2
,
0
>>
,
sequence
<
1
,
2
>
,
sequence
<
0
,
0
>>
{});
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetSoftmax
()
{
using
softmax_problem
=
BlockSoftmax2DProblem
<
typename
Problem
::
WeightType
>
;
return
BlockSoftmax2D
<
softmax_problem
>
{};
}
template
<
typename
Problem
>
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetTopk
()
{
using
topk_problem
=
BlockTopkStream2DProblem
<
typename
Problem
::
WeightType
,
typename
Problem
::
IndexType
,
Problem
::
LanesPerRow
>
;
// Note: replicate is LanesPerRow
return
BlockTopkStream2D
<
topk_problem
>
{};
}
};
}
// namespace ck_tile
include/ck_tile/ops/topk_softmax/pipeline/topk_softmax_warp_per_row_problem.hpp
0 → 100644
View file @
0475a327
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include <string>
#include <type_traits>
namespace
ck_tile
{
template
<
typename
InputType_
,
typename
WeightType_
,
typename
IndexType_
,
index_t
Experts_
,
index_t
IssuesPerCol_
=
2
,
// issue along col, to make sure block_reduce() OK
index_t
BytesPerIssue_
=
sizeof
(
InputType_
),
index_t
LaunchType_
=
0
,
// 0-streaming, >0, persistent #occupancy
index_t
BlockSize_
=
256
>
struct
TopkSoftmaxWarpPerRowProblem
{
// TODO: this kernel only support warp per row
using
InputType
=
remove_cvref_t
<
InputType_
>
;
using
WeightType
=
remove_cvref_t
<
WeightType_
>
;
using
IndexType
=
remove_cvref_t
<
IndexType_
>
;
static
constexpr
index_t
LaunchType
=
LaunchType_
;
static
constexpr
index_t
Experts
=
Experts_
;
static
constexpr
index_t
BytesPerIssue
=
BytesPerIssue_
;
static
constexpr
index_t
IssuesPerCol
=
IssuesPerCol_
;
static
constexpr
index_t
BlockSize
=
BlockSize_
;
static
constexpr
index_t
WarpSize
=
get_warp_size
();
static_assert
(
BytesPerIssue
%
sizeof
(
InputType
)
==
0
);
static
constexpr
index_t
VectorSize
=
BytesPerIssue
/
sizeof
(
InputType
);
static_assert
(
Experts
%
VectorSize
==
0
);
static
constexpr
index_t
LanesPerRow
=
min
(
Experts
/
VectorSize
,
WarpSize
);
static_assert
(
WarpSize
%
LanesPerRow
==
0
);
static
constexpr
index_t
RowsPerWarpPerColIssue
=
WarpSize
/
LanesPerRow
;
static
constexpr
index_t
RowsPerWarp
=
IssuesPerCol
*
RowsPerWarpPerColIssue
;
static
constexpr
index_t
IssuesPerRow
=
Experts
/
(
LanesPerRow
*
VectorSize
);
static
constexpr
index_t
WarpsPerBlock
=
BlockSize
/
WarpSize
;
static
constexpr
index_t
RowsPerBlock
=
RowsPerWarp
*
WarpsPerBlock
;
};
}
// namespace ck_tile
include/ck_tile/ops/welford.hpp
View file @
0475a327
...
@@ -6,4 +6,5 @@
...
@@ -6,4 +6,5 @@
#include "ck_tile/ops/welford/block/block_welford.hpp"
#include "ck_tile/ops/welford/block/block_welford.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
#include "ck_tile/ops/welford/block/block_welford_problem.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/welford/thread/thread_welford.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/welford/block/block_welford.hpp
View file @
0475a327
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
@@ -276,8 +276,8 @@ struct BlockWelfordCrossWarpSync
...
@@ -276,8 +276,8 @@ struct BlockWelfordCrossWarpSync
fp32x4_t
all_scratch
[
thread_buf_size
*
num_reduce_warps
];
fp32x4_t
all_scratch
[
thread_buf_size
*
num_reduce_warps
];
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
static_for
<
0
,
num_reduce_warps
,
1
>
{}([
&
](
auto
i_1
)
{
static_for
<
0
,
num_reduce_warps
,
1
>
{}([
&
](
auto
i_1
)
{
all_scratch
[
i_0
*
num_warps
+
i_1
]
=
all_scratch
[
i_0
*
num_
reduce_
warps
+
i_1
]
=
smem_ptr
[
i_0
*
num_
reduce_
warps
+
local_smem_os
+
i_1
];
smem_ptr
[
i_0
*
num_warps
+
local_smem_os
+
i_1
];
});
});
});
});
block_sync_lds
();
// TODO: we don't need sync here
block_sync_lds
();
// TODO: we don't need sync here
...
@@ -286,7 +286,7 @@ struct BlockWelfordCrossWarpSync
...
@@ -286,7 +286,7 @@ struct BlockWelfordCrossWarpSync
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
static_for
<
0
,
thread_buf_size
,
1
>
{}([
&
](
auto
i_0
)
{
// TODO: use descriptor for this
// TODO: use descriptor for this
auto
v_local
=
all_scratch
[
i_0
*
num_warps
];
auto
v_local
=
all_scratch
[
i_0
*
num_
reduce_
warps
];
auto
v_local_mean
=
bit_cast
<
DataType
>
(
v_local
[
0
]);
auto
v_local_mean
=
bit_cast
<
DataType
>
(
v_local
[
0
]);
auto
v_local_var
=
bit_cast
<
DataType
>
(
v_local
[
1
]);
auto
v_local_var
=
bit_cast
<
DataType
>
(
v_local
[
1
]);
auto
v_local_count
=
bit_cast
<
int
>
(
v_local
[
2
]);
auto
v_local_count
=
bit_cast
<
int
>
(
v_local
[
2
]);
...
@@ -294,7 +294,7 @@ struct BlockWelfordCrossWarpSync
...
@@ -294,7 +294,7 @@ struct BlockWelfordCrossWarpSync
// further reduce mean/var
// further reduce mean/var
static_for
<
0
,
num_reduce_warps
-
1
,
1
>
{}([
&
](
auto
i_1_n1
)
{
static_for
<
0
,
num_reduce_warps
-
1
,
1
>
{}([
&
](
auto
i_1_n1
)
{
constexpr
auto
i_1
=
number
<
i_1_n1
+
1
>
{};
constexpr
auto
i_1
=
number
<
i_1_n1
+
1
>
{};
const
fp32x4_t
v_remote
=
all_scratch
[
i_0
*
num_warps
+
i_1
];
const
fp32x4_t
v_remote
=
all_scratch
[
i_0
*
num_
reduce_
warps
+
i_1
];
const
auto
v_remote_mean
=
bit_cast
<
DataType
>
(
v_remote
[
0
]);
const
auto
v_remote_mean
=
bit_cast
<
DataType
>
(
v_remote
[
0
]);
const
auto
v_remote_var
=
bit_cast
<
DataType
>
(
v_remote
[
1
]);
const
auto
v_remote_var
=
bit_cast
<
DataType
>
(
v_remote
[
1
]);
const
auto
v_remote_count
=
bit_cast
<
int
>
(
v_remote
[
2
]);
const
auto
v_remote_count
=
bit_cast
<
int
>
(
v_remote
[
2
]);
...
@@ -356,7 +356,8 @@ CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTe
...
@@ -356,7 +356,8 @@ CK_TILE_DEVICE constexpr void block_tile_welford_post_scale_var(VarDistributedTe
int
count
)
int
count
)
{
{
using
DataType
=
typename
VarDistributedTensor_
::
DataType
;
using
DataType
=
typename
VarDistributedTensor_
::
DataType
;
tile_elementwise_inout
([
&
count
](
auto
&
x
)
{
x
=
x
/
type_convert
<
DataType
>
(
count
);
},
tile_elementwise_inout
(
var_tensor
);
[
&
count
](
auto
&
x
)
{
x
=
x
*
__builtin_amdgcn_rcpf
(
type_convert
<
DataType
>
(
count
));
},
var_tensor
);
}
}
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/ops/welford/thread/thread_welford.hpp
View file @
0475a327
...
@@ -12,7 +12,7 @@ CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count)
...
@@ -12,7 +12,7 @@ CK_TILE_DEVICE void welford_update(T& mean, T& var, T x, int count)
{
{
// TODO: check nan? maybe no
// TODO: check nan? maybe no
T
delta
=
x
-
mean
;
T
delta
=
x
-
mean
;
mean
+=
delta
/
count
;
mean
+=
delta
*
__builtin_amdgcn_rcpf
(
count
)
;
T
delta2
=
x
-
mean
;
T
delta2
=
x
-
mean
;
var
+=
delta
*
delta2
;
var
+=
delta
*
delta2
;
}
}
...
@@ -21,11 +21,12 @@ template <typename T>
...
@@ -21,11 +21,12 @@ template <typename T>
CK_TILE_DEVICE
static
void
CK_TILE_DEVICE
static
void
welford_merge
(
T
&
mean_a
,
T
&
var_a
,
int
&
count_a
,
T
mean_b
,
T
var_b
,
int
count_b
)
welford_merge
(
T
&
mean_a
,
T
&
var_a
,
int
&
count_a
,
T
mean_b
,
T
var_b
,
int
count_b
)
{
{
int
count
=
count_a
+
count_b
;
int
count
=
count_a
+
count_b
;
T
count_
=
type_convert
<
T
>
(
count
);
T
count_
=
type_convert
<
T
>
(
count
);
T
count_a_
=
type_convert
<
T
>
(
count_a
);
T
count_a_
=
type_convert
<
T
>
(
count_a
);
T
count_b_
=
type_convert
<
T
>
(
count_b
);
T
count_b_
=
type_convert
<
T
>
(
count_b
);
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
count_b_
/
count_
;
T
count_b_over_count
=
count
==
0
?
type_convert
<
T
>
(
0
)
:
count_b_
*
__builtin_amdgcn_rcpf
(
count_
);
T
delta
=
mean_b
-
mean_a
;
T
delta
=
mean_b
-
mean_a
;
mean_a
+=
delta
*
count_b_over_count
;
mean_a
+=
delta
*
count_b_over_count
;
...
...
include/ck_tile/remod.py
View file @
0475a327
from
datetime
import
datetime
import
pathlib
import
pathlib
from
pathlib
import
Path
from
pathlib
import
Path
import
subprocess
import
subprocess
...
@@ -8,8 +9,8 @@ NS = 'ck_tile'
...
@@ -8,8 +9,8 @@ NS = 'ck_tile'
OPS
=
'ops'
OPS
=
'ops'
OPS_COMMON
=
'common'
# common header will be duplicated into ops/* other module
OPS_COMMON
=
'common'
# common header will be duplicated into ops/* other module
HEADER_COMMON
=
"""// SPDX-License-Identifier: MIT
HEADER_COMMON
=
f
"""// SPDX-License-Identifier: MIT
// Copyright (c) 2018-
2024
, Advanced Micro Devices, Inc. All rights reserved.
\n
// Copyright (c) 2018-
{
datetime
.
now
().
year
}
, Advanced Micro Devices, Inc. All rights reserved.
\n
"""
"""
# aa/bb/cc/file.hpp -> (aa, bb, cc, file.hpp)
# aa/bb/cc/file.hpp -> (aa, bb, cc, file.hpp)
...
...
library/src/tensor_operation_instance/gpu/CMakeLists.txt
View file @
0475a327
...
@@ -67,6 +67,21 @@ function(add_instance_library INSTANCE_NAME)
...
@@ -67,6 +67,21 @@ function(add_instance_library INSTANCE_NAME)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endif
()
endforeach
()
endforeach
()
# Do not build gemm_universal_f8 or gemm_multiply_multiply_f8 for any targets except gfx94
if
(
NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH
)
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"gemm_multiply_multiply_xdl_f8"
)
message
(
"removing gemm_multiply_multiply_f8 instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
foreach
(
source IN LISTS ARGN
)
if
(
NOT INST_TARGETS MATCHES
"gfx94"
AND source MATCHES
"gemm_xdl_universal"
AND source MATCHES
"_f8_"
)
message
(
"removing gemm_universal_f8 instance
${
source
}
"
)
list
(
REMOVE_ITEM ARGN
"
${
source
}
"
)
endif
()
endforeach
()
endif
()
#only continue if there are some source files left on the list
#only continue if there are some source files left on the list
if
(
ARGN
)
if
(
ARGN
)
set
(
INST_OBJ
)
set
(
INST_OBJ
)
...
@@ -74,11 +89,20 @@ function(add_instance_library INSTANCE_NAME)
...
@@ -74,11 +89,20 @@ function(add_instance_library INSTANCE_NAME)
set
(
INST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
set
(
INST_TARGETS
${
SUPPORTED_GPU_TARGETS
}
)
if
(
source MATCHES
"_xdl"
)
if
(
source MATCHES
"_xdl"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
elseif
(
ARGN
MATCHES
"_wmma"
)
elseif
(
source
MATCHES
"_wmma"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx940 gfx941 gfx942 gfx1030
)
elseif
(
ARGN
MATCHES
"mha"
)
elseif
(
source
MATCHES
"mha"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
endif
()
endif
()
#only build the fp8 gemm instances for gfx908/90a if the build argument is set
if
(
NOT CK_USE_FP8_ON_UNSUPPORTED_ARCH
)
if
(
source MATCHES
"gemm_xdl_universal"
AND source MATCHES
"f8"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
endif
()
if
(
source MATCHES
"gemm_multiply_multiply_f8"
)
list
(
REMOVE_ITEM INST_TARGETS gfx900 gfx906 gfx908 gfx90a gfx1030 gfx1100 gfx1101 gfx1102 gfx1103 gfx1200 gfx1201
)
endif
()
endif
()
set
(
offload_targets
)
set
(
offload_targets
)
foreach
(
target IN LISTS INST_TARGETS
)
foreach
(
target IN LISTS INST_TARGETS
)
string
(
APPEND offload_targets
"--offload-arch=
${
target
}
"
)
string
(
APPEND offload_targets
"--offload-arch=
${
target
}
"
)
...
@@ -108,7 +132,7 @@ function(add_instance_library INSTANCE_NAME)
...
@@ -108,7 +132,7 @@ function(add_instance_library INSTANCE_NAME)
# flags to compress the library
# flags to compress the library
if
(
NOT WIN32 AND
${
hip_VERSION_FLAT
}
GREATER 600241132
)
if
(
NOT WIN32 AND
${
hip_VERSION_FLAT
}
GREATER 600241132
)
message
(
"Adding --offload-compress flag for
${
INSTANCE_NAME
}
"
)
#
message("Adding --offload-compress flag for ${INSTANCE_NAME}")
target_compile_options
(
${
INSTANCE_NAME
}
PRIVATE --offload-compress
)
target_compile_options
(
${
INSTANCE_NAME
}
PRIVATE --offload-compress
)
endif
()
endif
()
...
...
library/src/tensor_operation_instance/gpu/gemm_multiply_multiply/device_gemm_multiply_multiply_xdl_f8_f8_bf16/device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn.hpp
View file @
0475a327
...
@@ -36,12 +36,12 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
...
@@ -36,12 +36,12 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template
<
GemmSpecialization
GemmSpec
>
template
<
GemmSpecialization
GemmSpec
>
using
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances
=
std
::
tuple
<
using
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Compute friendly
// Compute friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
...
@@ -58,17 +58,18 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std
...
@@ -58,17 +58,18 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_comp_instances = std
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
#endif
// clang-format on
// clang-format on
>
;
>
;
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
using
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances
=
std
::
tuple
<
using
device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| ALayout| BLayout| DsLayout| ELayout|AData| BData| DsData| EData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | Type| Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//################################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Latency friendly
// Latency friendly
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
2
,
2
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
...
@@ -90,6 +91,7 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances = std:
...
@@ -90,6 +91,7 @@ using device_gemm_multiply_multiply_xdl_f8_f8_bf16_mk_nk_mn_mem_instances = std:
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
128
,
32
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
16
,
256
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
16
,
256
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
4
,
4
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
DeviceGemmMultiD_Xdl_CShuffle_V3
<
Row
,
Col
,
Tuple
<
Row
,
Col
>
,
Row
,
F8
,
F8
,
Tuple
<
F32
,
F32
>
,
BF16
,
F32
,
F32
,
PassThrough
,
PassThrough
,
MultiplyMultiply
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
S
<
8
,
8
,
1
>
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
#endif
// clang-format on
// clang-format on
>
;
>
;
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn.hpp
View file @
0475a327
...
@@ -62,12 +62,12 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances = std::tuple<
...
@@ -62,12 +62,12 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_comp_instances = std::tuple<
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
using
device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_instances
=
std
::
tuple
<
using
device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Latency friendly
// Latency friendly
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
4
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
4
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
4
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
4
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
4
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
8
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
...
@@ -90,6 +90,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_instances = std::tuple<
...
@@ -90,6 +90,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_kn_mn_mem_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
64
,
128
,
16
,
4
,
16
,
16
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
64
,
128
,
16
,
4
,
16
,
16
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
32
,
4
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
4
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
128
,
128
,
16
,
8
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
16
,
128
,
128
,
16
,
8
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
16
,
8
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
128
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Row
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
128
,
8
,
8
,
16
,
16
,
1
,
4
,
S
<
16
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
8
,
8
,
0
,
S
<
16
,
16
,
1
>
,
S
<
0
,
2
,
1
>
,
S
<
0
,
2
,
1
>
,
1
,
16
,
8
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
#endif
// clang-format on
// clang-format on
>
;
>
;
}
// namespace instance
}
// namespace instance
...
...
library/src/tensor_operation_instance/gpu/gemm_universal/device_gemm_xdl_universal_f8_f8_bf16/device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn.hpp
View file @
0475a327
...
@@ -35,12 +35,12 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
...
@@ -35,12 +35,12 @@ static constexpr auto Interwave = BlockGemmPipelineScheduler::Interwave;
template
<
GemmSpecialization
GemmSpec
>
template
<
GemmSpecialization
GemmSpec
>
using
device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances
=
std
::
tuple
<
using
device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Compute friendly
// Compute friendly
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
256
,
256
,
64
,
16
,
16
,
32
,
32
,
4
,
4
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
4
,
64
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
128
,
128
,
16
,
16
,
32
,
32
,
2
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v4
,
F8
>
,
...
@@ -57,17 +57,18 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple<
...
@@ -57,17 +57,18 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_comp_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
128
,
64
,
128
,
16
,
16
,
32
,
32
,
2
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
,
// DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
// DeviceGemm_Xdl_CShuffleV3< Row, Col, Row, F8, F8, BF16, F32, BF16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 64, 128, 128, 16, 16, 32, 32, 1, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3, F8>,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
64
,
64
,
128
,
16
,
16
,
32
,
32
,
1
,
1
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
32
,
1
,
8
>
,
8
,
BlockGemmPipelineScheduler
::
Intrawave
,
BlockGemmPipelineVersion
::
v3
,
F8
>
#endif
// clang-format on
// clang-format on
>
;
>
;
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
template
<
BlockGemmPipelineScheduler
BlkGemmPipeSched
,
GemmSpecialization
GemmSpec
>
using
device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances
=
std
::
tuple
<
using
device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances
=
std
::
tuple
<
// clang-format off
// clang-format off
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| ALayout| BLayout| CLayout|AData| BData| CData| AccData| Cshuffle| A| B| C| GEMM| Block| MPer| NPer| KPer| AK1| BK1|MPer| NPer| MXdl| NXdl| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockLds| BBlockTransfer| BBlockTransfer| BBlockTransfer| BlockTransfer| BBlockTransfer| BBlockTransfer| BBlockLds| CShuffle| CShuffle| CBlockTransferClusterLengths| CBlockTransfer| Block-wiseGemm| Block-wiseGemm|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | Type| Type| Type| Type| Type| Elementwise| Elementwise| Elementwise|Specialization| Size| Block| Block| Block| | | XDL| XDL| Per| Per| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraM| ThreadCluster| ThreadCluster| SrcAccessOrder| SrcVectorDim| SrcScalar| DstScalar| AddExtraN| MXdlPerWave| NXdlPerWave| _MBlock_MXdlPerWave_MWaveMPerXdl| ScalarPerVector| Pipeline| Pipeline|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | Operation| Operation| Operation| | | | | | | | | | Wave| Wave| Lengths_K0_M_K1| ArrangeOrder| | | PerVector| PerVector_K1| | Lengths_K0_N_K1| ArrangeOrder| | | PerVector| PerVector_K1| | PerShuffle| PerShuffle| _NBlock_NXdlPerWave_NWaveNPerXdl| _NWaveNPerXdl| Scheduler| Verision|
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
//#########################| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
#if defined(__gfx94__) || defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
// Latency friendly
// Latency friendly
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
2
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
64
,
16
,
16
,
128
,
16
,
16
,
16
,
16
,
1
,
1
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
8
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
4
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v1
,
F8
>
,
...
@@ -97,6 +98,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple<
...
@@ -97,6 +98,7 @@ using device_gemm_xdl_universal_f8_f8_bf16_mk_nk_mn_mem_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
128
,
32
,
128
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
8
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
16
,
256
,
128
,
16
,
16
,
16
,
16
,
1
,
4
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
16
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
4
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
,
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
DeviceGemm_Xdl_CShuffleV3
<
Row
,
Col
,
Row
,
F8
,
F8
,
BF16
,
F32
,
BF16
,
PassThrough
,
PassThrough
,
PassThrough
,
GemmSpec
,
256
,
32
,
256
,
128
,
16
,
16
,
32
,
32
,
1
,
2
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
S
<
8
,
32
,
1
>
,
S
<
1
,
0
,
2
>
,
S
<
1
,
0
,
2
>
,
2
,
16
,
16
,
0
,
1
,
1
,
S
<
1
,
16
,
1
,
16
>
,
8
,
BlkGemmPipeSched
,
BlockGemmPipelineVersion
::
v2
,
F8
>
#endif
// clang-format on
// clang-format on
>
;
>
;
}
// namespace instance
}
// namespace instance
...
...
profiler/src/profile_gemm_universal.cpp
View file @
0475a327
...
@@ -101,7 +101,9 @@ int profile_gemm_universal(int argc, char* argv[])
...
@@ -101,7 +101,9 @@ int profile_gemm_universal(int argc, char* argv[])
using
F32
=
float
;
using
F32
=
float
;
using
F16
=
ck
::
half_t
;
using
F16
=
ck
::
half_t
;
using
BF16
=
ck
::
bhalf_t
;
using
BF16
=
ck
::
bhalf_t
;
using
F8
=
ck
::
f8_t
;
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
using
F8
=
ck
::
f8_t
;
#endif
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Row
=
ck
::
tensor_layout
::
gemm
::
RowMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
using
Col
=
ck
::
tensor_layout
::
gemm
::
ColumnMajor
;
...
@@ -162,6 +164,7 @@ int profile_gemm_universal(int argc, char* argv[])
...
@@ -162,6 +164,7 @@ int profile_gemm_universal(int argc, char* argv[])
{
{
return
profile
(
F16
{},
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F16
{},
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
}
}
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F16_F8_F16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
F16
{},
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F16
{},
F8
{},
F16
{},
F32
{},
F16
{},
Row
{},
Row
{},
Row
{});
...
@@ -178,6 +181,7 @@ int profile_gemm_universal(int argc, char* argv[])
...
@@ -178,6 +181,7 @@ int profile_gemm_universal(int argc, char* argv[])
{
{
return
profile
(
F8
{},
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F8
{},
F16
{},
F16
{},
F32
{},
F16
{},
Row
{},
Col
{},
Row
{});
}
}
#endif
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
BF16_BF16_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
BF16
{},
BF16
{},
BF16
{},
F32
{},
BF16
{},
Row
{},
Row
{},
Row
{});
return
profile
(
BF16
{},
BF16
{},
BF16
{},
F32
{},
BF16
{},
Row
{},
Row
{},
Row
{});
...
@@ -194,6 +198,7 @@ int profile_gemm_universal(int argc, char* argv[])
...
@@ -194,6 +198,7 @@ int profile_gemm_universal(int argc, char* argv[])
{
{
return
profile
(
BF16
{},
BF16
{},
BF16
{},
F32
{},
BF16
{},
Col
{},
Row
{},
Row
{});
return
profile
(
BF16
{},
BF16
{},
BF16
{},
F32
{},
BF16
{},
Col
{},
Row
{},
Row
{});
}
}
#if defined(CK_USE_FP8_ON_UNSUPPORTED_ARCH)
else
if
(
data_type
==
GemmDataType
::
F8_F8_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
else
if
(
data_type
==
GemmDataType
::
F8_F8_BF16
&&
layout
==
GemmMatrixLayout
::
MK_KN_MN
)
{
{
return
profile
(
F8
{},
F8
{},
F8
{},
F32
{},
BF16
{},
Row
{},
Row
{},
Row
{});
return
profile
(
F8
{},
F8
{},
F8
{},
F32
{},
BF16
{},
Row
{},
Row
{},
Row
{});
...
@@ -202,6 +207,7 @@ int profile_gemm_universal(int argc, char* argv[])
...
@@ -202,6 +207,7 @@ int profile_gemm_universal(int argc, char* argv[])
{
{
return
profile
(
F8
{},
F8
{},
F8
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{});
return
profile
(
F8
{},
F8
{},
F8
{},
F32
{},
BF16
{},
Row
{},
Col
{},
Row
{});
}
}
#endif
else
else
{
{
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
std
::
cout
<<
"this data_type & layout is not implemented"
<<
std
::
endl
;
...
...
test/CMakeLists.txt
View file @
0475a327
...
@@ -210,3 +210,4 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL
...
@@ -210,3 +210,4 @@ if(SUPPORTED_GPU_TARGETS MATCHES "gfx942" AND CK_HIP_VERSION_MAJOR GREATER_EQUAL
add_subdirectory
(
smfmac_op
)
add_subdirectory
(
smfmac_op
)
endif
()
endif
()
add_subdirectory
(
position_embedding
)
add_subdirectory
(
position_embedding
)
add_subdirectory
(
scatter_gather
)
test/ck_tile/CMakeLists.txt
View file @
0475a327
add_subdirectory
(
image_to_column
)
add_subdirectory
(
image_to_column
)
add_subdirectory
(
gemm
)
Prev
1
…
9
10
11
12
13
14
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment