gaoqiong / composable_kernel_ROCM · Commits

Commit ef5e60f6, authored Dec 11, 2024 by illsilin
Merge branch 'develop' into gfx950
Parents: 2cc0fa26, 5e93fa9e
Changes: 403

Showing 20 changed files with 1875 additions and 66 deletions (+1875 -66)
include/ck_tile/core/tensor/tile_window.hpp  (+72, -2)
include/ck_tile/core/tensor/tile_window_linear.hpp  (+142, -17)
include/ck_tile/core/tensor/tile_window_utils.hpp  (+54, -0)
include/ck_tile/core/tensor/update_tile.hpp  (+53, -3)
include/ck_tile/core/utility/amd_address_space.hpp  (+37, -0)
include/ck_tile/core/utility/static_counter.hpp  (+116, -0)
include/ck_tile/host.hpp  (+2, -0)
include/ck_tile/host/device_memory.hpp  (+35, -0)
include/ck_tile/host/fill.hpp  (+107, -6)
include/ck_tile/host/host_tensor.hpp  (+103, -18)
include/ck_tile/host/joinable_thread.hpp  (+27, -0)
include/ck_tile/host/reference/reference_fused_moe.hpp  (+196, -0)
include/ck_tile/host/reference/reference_gemm.hpp  (+112, -0)
include/ck_tile/host/reference/reference_moe_sorting.hpp  (+24, -5)
include/ck_tile/host/reference/reference_permute.hpp  (+21, -2)
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp  (+30, -7)
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp  (+20, -6)
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp  (+99, -0)
include/ck_tile/ops/flatmm.hpp  (+10, -0)
include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp  (+615, -0)
include/ck_tile/core/tensor/tile_window.hpp  (view file @ ef5e60f6)

...
@@ -292,12 +292,15 @@ struct tile_window_with_static_distribution
    {
        constexpr auto tile_dstr = TileDstr{};

        auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);

-       load(dst_tensor, bool_constant<oob_conditional_check>{});
+       load(dst_tensor, number<i_access_unsupport_>{}, bool_constant<oob_conditional_check>{});

        return dst_tensor;
    }

-   template <typename DistributedTensor, bool oob_conditional_check = true>
+   template <typename DistributedTensor,
+             index_t i_access_unsupport_ = -1,
+             bool oob_conditional_check  = true>
    CK_TILE_DEVICE auto load(DistributedTensor& dst_tensor,
                             number<i_access_unsupport_>          = {},
                             bool_constant<oob_conditional_check> = {}) const
    {
        using Traits = load_store_traits;
...
@@ -785,6 +788,73 @@ struct tile_window_with_static_distribution
        });
    }

    template <index_t i_access_unsupport_ = -1, bool oob_conditional_check = true, bool pre_nop>
    CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                                   number<i_access_unsupport_>          = {},
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop>               = {}) const
    {
        using Traits   = load_store_traits;
        using vector_t = typename Traits::vector_t;
        using SFC_Ys   = typename Traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        static_for<0, NumCoord, 1>{}([&](auto iCoord) {
            /// TODO: use structure binding (to be captured later) if compiled in C++20
            auto window_adaptor_thread_coord = pre_computed_coords_[iCoord][I0];
            auto bottom_tensor_thread_coord  = pre_computed_coords_[iCoord][I1];

            static_for<0, NumAccessPerCoord, 1>{}([&](auto iCoordAccess) {
                constexpr auto iAccess = number<iCoord * NumAccessPerCoord + iCoordAccess>{};

                // data index [y0, y1, ...]
                constexpr auto idx_ys_start = SFC_Ys::get_index(iAccess);

                // read from distributed tensor
                vector_t vec_value;
                static_for<0, Traits::ScalarPerVector, 1>{}([&](auto j) {
                    constexpr auto idx_ys = generate_tuple(
                        [&](auto jj) {
                            return jj == Traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                            : idx_ys_start[jj];
                        },
                        number<NDimY>{});

                    constexpr index_t d =
                        tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                    vec_value.template get_as<DataType>()(j) =
                        dstr_tensor.get_thread_buffer().template at<d>();
                });

                // write into bottom tensor
                get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
                    bottom_tensor_thread_coord,
                    0,
                    vec_value,
                    bool_constant<oob_conditional_check>{},
                    bool_constant<pre_nop>{});

                // move thread coordinate
                if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
                {
                    constexpr auto idx_diff_ys = SFC_Ys::get_forward_step(iAccess);

                    constexpr auto idx_diff_ps_ys = container_concat(
                        generate_tuple([&](auto) { return number<0>{}; }, number<NDimP>{}),
                        idx_diff_ys);

                    move_window_adaptor_and_bottom_tensor_thread_coordinate(
                        window_adaptor_thread_coord, bottom_tensor_thread_coord, idx_diff_ps_ys);
                }
            });
        });
    }
// move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset]
// also move window-origin
...
...
include/ck_tile/core/tensor/tile_window_linear.hpp  (view file @ ef5e60f6)

...
@@ -432,23 +432,38 @@ struct tile_window_linear
    CK_TILE_DEVICE static constexpr index_t get_bottom_linear_offset(number<i_access>)
    {
        constexpr auto linear_coord = get_bottom_linear_coordinate(number<i_access>{});

        // since this is linear offset, we assum bottom X tensor is always linear
-       constexpr index_t linear_offset = [&]() {
-           constexpr auto x_idx_ = linear_coord;
-           constexpr auto x_len_ = TileDstr{}.get_lengths();
-           static_assert(x_idx_.size() == x_len_.size());
-           constexpr index_t x_dims_ = x_idx_.size();
-           index_t cu_stride_ = 1;
-           index_t cu_offset_ = 0;
-           static_for<0, x_dims_, 1>{}([&](auto i_) {
-               auto r_i_ = number<x_dims_ - i_ - 1>{};
-               cu_offset_ += x_idx_[r_i_] * cu_stride_;
-               cu_stride_ *= x_len_[r_i_];
-           });
-           return cu_offset_;
-       }();
-       return linear_offset;
+       constexpr auto is_pure_linear_tensor =
+           reduce_on_sequence(LinearBottomDims{}, multiplies{}, number<1>{});
+
+       if constexpr(is_pure_linear_tensor)
+       {
+           // this case usually is a LDS window, everything is known at compile tile.
+           // we directly use BottomTensorView transform to compute the offset, in case padding
+           auto bottom_tensor_coord =
+               make_tensor_coordinate(BottomTensorView{}.get_tensor_descriptor(), linear_coord);
+           return bottom_tensor_coord.get_offset();
+       }
+       else
+       {
+           // this case usually is a global window, where last dim can be linear
+           // we hack here, that use the original TileDstr to compute the linear offset
+           // ... hoping that there is no extra padding between other dims, which make sense
+           // since that would introduce runtime length (so can't use linear offset)
+           constexpr index_t linear_offset = [&]() {
+               constexpr auto x_idx_ = linear_coord;
+               constexpr auto x_len_ = TileDstr{}.get_lengths();
+               static_assert(x_idx_.size() == x_len_.size());
+               constexpr index_t x_dims_ = x_idx_.size();
+               index_t cu_stride_ = 1;
+               index_t cu_offset_ = 0;
+               static_for<0, x_dims_, 1>{}([&](auto i_) {
+                   auto r_i_ = number<x_dims_ - i_ - 1>{};
+                   cu_offset_ += x_idx_[r_i_] * cu_stride_;
+                   cu_stride_ *= x_len_[r_i_];
+               });
+               return cu_offset_;
+           }();
+           return linear_offset;
+       }
    }

    CK_TILE_DEVICE constexpr auto get_num_of_access() const { return traits::NumAccess; }
...
@@ -509,6 +524,64 @@ struct tile_window_linear
        return dst_tensor;
    }

    template <typename DstTile, index_t i_access = -1, bool oob_conditional_check = true>
    CK_TILE_DEVICE auto load(DstTile& dst_tensor,
                             number<i_access>                     = {},
                             bool_constant<oob_conditional_check> = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};
        // auto dst_tensor = make_static_distributed_tensor<DataType>(tile_dstr);

        auto issue = [&](auto i_access_) {
            constexpr auto IAccess       = number<i_access_>{};
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};

            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            auto bottom_tensor_flag         = cached_flags_[IAccess];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);

            // read from bottom tensor
            const vector_t vec_value =
                get_bottom_tensor_view().template get_vectorized_elements<vector_t>(
                    bottom_tensor_thread_coord,
                    linear_offset,
                    bottom_tensor_flag,
                    bool_constant<oob_conditional_check>{});
#if 1
            // data index [y0, y1, ...]
            constexpr auto idx_diff_ys = SFC_Ys::get_index(IAccess);

            // write into distributed tensor
            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_diff_ys[jj] + j)
                                                        : idx_diff_ys[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                dst_tensor.get_thread_buffer().template at<d>() =
                    vec_value.template get_as<DataType>()[j];
            });
#else
            constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys_start);
            static_assert(d % traits::ScalarPerVector == 0);

            dst_tensor.get_thread_buffer().template get_as<vector_t>()(
                number<d / traits::ScalarPerVector>{}) = bit_cast<vector_t>(vec_value);
#endif
        };

        WINDOW_DISPATCH_ISSUE();

        return dst_tensor;
    }

    template <typename DstTile,
              index_t i_access           = -1,
              bool oob_conditional_check = true,
...
@@ -849,6 +922,58 @@ struct tile_window_linear
        WINDOW_DISPATCH_ISSUE();
    }

    template <index_t i_access = -1, bool oob_conditional_check = true, bool pre_nop = false>
    CK_TILE_DEVICE void update_raw(const static_distributed_tensor<DataType, TileDstr>& dstr_tensor,
                                   number<i_access>                     = {},
                                   bool_constant<oob_conditional_check> = {},
                                   bool_constant<pre_nop>               = {}) const
    {
        using vector_t = typename traits::vector_t;
        using SFC_Ys   = typename traits::SFC_Ys;

        constexpr auto tile_dstr = TileDstr{};

        // loop over thread tensor space [y0, y1, ...]
        auto issue = [&](auto i_access_) {
            constexpr auto IAccess       = number<i_access_>{};
            constexpr auto non_linear_id = number<AccessMap_NonLinear{}[IAccess]>{};

            auto bottom_tensor_thread_coord = cached_coords_[non_linear_id];
            constexpr auto linear_offset    = get_bottom_linear_offset(IAccess);
            auto bottom_tensor_flag         = cached_flags_[IAccess];

            // data index [y0, y1, ...]
            constexpr auto idx_ys_start = SFC_Ys::get_index(IAccess);

            // read from distributed tensor
            vector_t vec_value;
            static_for<0, traits::ScalarPerVector, 1>{}([&](auto j) {
                constexpr auto idx_ys = generate_tuple(
                    [&](auto jj) {
                        return jj == traits::VectorDimY ? (idx_ys_start[jj] + j)
                                                        : idx_ys_start[jj];
                    },
                    number<NDimY>{});

                constexpr index_t d = tile_dstr.get_ys_to_d_descriptor().calculate_offset(idx_ys);

                vec_value.template get_as<DataType>()(j) =
                    dstr_tensor.get_thread_buffer().template at<d>();
            });

            // write into bottom tensor
            get_bottom_tensor_view().template update_vectorized_elements_raw<vector_t>(
                bottom_tensor_thread_coord,
                linear_offset,
                bottom_tensor_flag,
                vec_value,
                bool_constant<oob_conditional_check>{},
                bool_constant<pre_nop>{});
        };

        WINDOW_DISPATCH_ISSUE();
    }
// move thread's bottom tensor coordinate
// [x0', x1', ... ] ==> [offset]
// also move window-origin
...
...
include/ck_tile/core/tensor/tile_window_utils.hpp  (new file, mode 100644)  (view file @ ef5e60f6)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#include "ck_tile/core/arch/arch.hpp"
#include "ck_tile/core/arch/utility.hpp"
#include "ck_tile/core/algorithm/space_filling_curve.hpp"
#include "ck_tile/core/config.hpp"
#include "ck_tile/core/container/array.hpp"
#include "ck_tile/core/container/sequence.hpp"
#include "ck_tile/core/container/tuple.hpp"
#include "ck_tile/core/container/container_helper.hpp"
#include "ck_tile/core/tensor/static_distributed_tensor.hpp"
#include "ck_tile/core/tensor/tensor_adaptor.hpp"
#include "ck_tile/core/tensor/tile_distribution.hpp"
#include "ck_tile/core/utility/functional.hpp"
#include "ck_tile/core/utility/type_traits.hpp"

#pragma once

namespace ck_tile {

// input a lds store tile, extract some information from it
// used to set m0 value for gfx9 serious
template <typename LdsTileWindow_>
CK_TILE_DEVICE auto get_async_store_smem_info(LdsTileWindow_&& lds_tile)
{
    using LdsTileWindow = remove_cvref_t<LdsTileWindow_>;
    using LdsDataType   = typename LdsTileWindow::DataType;

    // issues * warps * lanes
    static_assert(LdsTileWindow::get_num_of_dimension() == 3); // TODO: hard coded

    const index_t size_per_buf =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<0>{}, number<0>{}, number<0>{})) *
        sizeof(LdsDataType);

    const index_t size_per_wave =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<0>{}, number<1>{}, number<0>{})) *
            sizeof(LdsDataType) -
        size_per_buf;

    const index_t size_per_issue =
        lds_tile.get_bottom_tensor_view().get_tensor_descriptor().calculate_offset(
            make_tuple(number<1>{}, number<0>{}, number<0>{})) *
            sizeof(LdsDataType) -
        size_per_buf;

    const index_t m0_init_value = size_per_buf + size_per_wave * get_warp_id();

    return make_tuple(m0_init_value, size_per_issue);
}

} // namespace ck_tile
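The helper above probes three descriptor offsets of the LDS window and turns them into byte strides for programming the m0 register. A minimal host-side sketch (not from the commit; assumes a dense row-major issues x warps x lanes layout and hypothetical sizes) reproduces the same arithmetic:

// Sketch only: the arithmetic behind get_async_store_smem_info for a dense
// [issues, warps, lanes] LDS layout.
#include <cstdio>

int main()
{
    // hypothetical shape: 4 issues x 4 warps x 64 lanes of 2-byte elements
    const int issues = 4, warps = 4, lanes = 64, elem_bytes = 2;

    // byte offset of element (i, w, l) in a row-major descriptor
    auto offset_bytes = [&](int i, int w, int l) {
        return ((i * warps + w) * lanes + l) * elem_bytes;
    };

    const int size_per_buf   = offset_bytes(0, 0, 0);                // usually 0
    const int size_per_wave  = offset_bytes(0, 1, 0) - size_per_buf; // bytes between warps
    const int size_per_issue = offset_bytes(1, 0, 0) - size_per_buf; // bytes between issues

    // each wave programs m0 once, then advances by size_per_issue per async store
    for(int warp_id = 0; warp_id < warps; ++warp_id)
    {
        const int m0_init_value = size_per_buf + size_per_wave * warp_id;
        std::printf("warp %d: m0_init=%d, step=%d\n", warp_id, m0_init_value, size_per_issue);
    }
    return 0;
}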
include/ck_tile/core/tensor/update_tile.hpp  (view file @ ef5e60f6)

...
@@ -41,15 +41,65 @@ template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
-         typename DataType_>
+         typename DataType_,
+         index_t i_access           = -1,
+         bool oob_conditional_check = true>
CK_TILE_DEVICE void
update_tile(tile_window_with_static_distribution<BottomTensorView_,
                                                  WindowLengths_,
                                                  TileDistribution_,
                                                  NumCoord>& tile_window,
-           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor)
+           const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
+           number<i_access>                     = {},
+           bool_constant<oob_conditional_check> = {})
{
-   tile_window.update(dstr_tensor);
+   tile_window.update(dstr_tensor, number<i_access>{}, bool_constant<oob_conditional_check>{});
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          index_t NumCoord,
          typename DataType_,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE void
update_tile_raw(tile_window_with_static_distribution<BottomTensorView_,
                                                      WindowLengths_,
                                                      TileDistribution_,
                                                      NumCoord>& tile_window,
                const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
                number<i_access>                     = {},
                bool_constant<oob_conditional_check> = {},
                bool_constant<pre_nop>               = {})
{
    tile_window.update_raw(dstr_tensor,
                           number<i_access>{},
                           bool_constant<oob_conditional_check>{},
                           bool_constant<pre_nop>{});
}

template <typename BottomTensorView_,
          typename WindowLengths_,
          typename TileDistribution_,
          typename LinearBottomDims_,
          typename DataType_,
          index_t i_access           = -1,
          bool oob_conditional_check = true,
          bool pre_nop               = false>
CK_TILE_DEVICE auto
update_tile_raw(tile_window_linear<BottomTensorView_,
                                   WindowLengths_,
                                   TileDistribution_,
                                   LinearBottomDims_>& tile_window,
                const static_distributed_tensor<DataType_, TileDistribution_>& dstr_tensor,
                number<i_access>                     = {},
                bool_constant<oob_conditional_check> = {},
                bool_constant<pre_nop>               = {})
{
    tile_window.update_raw(dstr_tensor,
                           number<i_access>{},
                           bool_constant<oob_conditional_check>{},
                           bool_constant<pre_nop>{});
}

} // namespace ck_tile
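These free functions only forward to the window's member functions, with the new compile-time parameters defaulted. A hedged device-side fragment (not from the commit; `out_window` and `acc` are hypothetical names for a tile window and a matching static_distributed_tensor in the surrounding kernel) shows the intended call shapes:

// Sketch only: calling the extended overloads from kernel code.

// plain update, default i_access = -1 (all accesses), OOB check enabled
update_tile(out_window, acc);

// raw update with the new explicit parameters:
// number<i_access>, bool_constant<oob_conditional_check>, bool_constant<pre_nop>
update_tile_raw(out_window, acc, number<-1>{}, bool_constant<true>{}, bool_constant<false>{});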
include/ck_tile/core/utility/amd_address_space.hpp  (new file, mode 100644)  (view file @ ef5e60f6)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"

// Address Space for AMDGCN
// https://llvm.org/docs/AMDGPUUsage.html#address-space

namespace ck_tile {

#define CK_CONSTANT_ADDRESS_SPACE __attribute__((address_space(4)))

template <typename T>
__device__ T* cast_pointer_to_generic_address_space(T CK_CONSTANT_ADDRESS_SPACE* p)
{
    // cast a pointer in "Constant" address space (4) to "Generic" address space (0)
    // only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    return (T*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}

template <typename T>
__host__ __device__ T CK_CONSTANT_ADDRESS_SPACE* cast_pointer_to_constant_address_space(T* p)
{
    // cast a pointer in "Generic" address space (0) to "Constant" address space (4)
    // only c-style pointer cast seems be able to be compiled
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Wold-style-cast"
    return (T CK_CONSTANT_ADDRESS_SPACE*)p; // NOLINT(old-style-cast)
#pragma clang diagnostic pop
}

} // namespace ck_tile
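A hedged sketch (not from the commit; the kernel name and launch lines are hypothetical) of the pattern these two casts support: a pointer is carried through kernel arguments in the "Constant" address space (4) and converted back to a generic pointer before it is dereferenced.

// Sketch only: round-tripping a pointer through address space 4.
#include <hip/hip_runtime.h>
#include "ck_tile/core/utility/amd_address_space.hpp"

__global__ void copy_kernel(const float CK_CONSTANT_ADDRESS_SPACE* p_src_const,
                            float* p_dst,
                            int n)
{
    // constant address space -> generic address space before use
    const float* p_src = ck_tile::cast_pointer_to_generic_address_space(p_src_const);

    const int i = blockIdx.x * blockDim.x + threadIdx.x;
    if(i < n)
        p_dst[i] = p_src[i];
}

// host side: annotate an ordinary device pointer before the launch
//   auto p_const = ck_tile::cast_pointer_to_constant_address_space(p_device_src);
//   copy_kernel<<<grid, block>>>(p_const, p_device_dst, n);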
include/ck_tile/core/utility/static_counter.hpp  (new file, mode 100644)  (view file @ ef5e60f6)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/config.hpp"

namespace ck_tile {

template <typename Context, index_t Start = 0, index_t Step = 1>
struct static_counter
{
    public:
    template <typename Unique>
    static constexpr index_t next()
    {
        return next<Unique>(0) * Step + Start;
    }

    template <unsigned long long>
    static constexpr index_t next()
    {
        struct Unique
        {
        };
        return next<Unique>(0) * Step + Start;
    }

    template <typename Unique>
    static constexpr index_t current()
    {
        return current<Unique>(0) * Step + Start;
    }

    template <unsigned long long>
    static constexpr index_t current()
    {
        struct Unique
        {
        };
        return current<Unique>(0) * Step + Start;
    }

    private:
    template <index_t I>
    struct slot
    {
        _Pragma("GCC diagnostic push");
        _Pragma("GCC diagnostic ignored \"-Wundefined-internal\"");
        friend constexpr bool slot_allocated(slot<I>);
        _Pragma("GCC diagnostic pop");
    };

    template <index_t I>
    struct allocate_slot
    {
        friend constexpr bool slot_allocated(slot<I>) { return true; }
        enum
        {
            value = I
        };
    };

    // If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
    // the overload set...
    template <typename Unique, index_t I = 0, bool = slot_allocated(slot<I>())>
    static constexpr index_t next(index_t)
    {
        return next<Unique, I + 1>(0);
    }

    // ...And this function will be used, instead, which will define slot_allocated(slot<I>) via
    // allocate_slot<I>.
    template <typename Unique, index_t I = 0>
    static constexpr index_t next(double)
    {
        return allocate_slot<I>::value;
    }

    // If slot_allocated(slot<I>) has NOT been defined, then SFINAE will keep this function out of
    // the overload set...
    template <typename Unique, index_t I = Start, bool = slot_allocated(slot<I>())>
    static constexpr index_t current(index_t)
    {
        return current<Unique, I + 1>(0);
    }

    // ...And this function will be used, instead, which will return the current counter, or assert
    // in case next() hasn't been called yet.
    template <typename Unique, index_t I = Start>
    static constexpr index_t current(double)
    {
        static_assert(I != 0, "You must invoke next() first");
        return I - 1;
    }
};

namespace impl {
template <int I>
struct static_counter_uniq_;
}

#define MAKE_SC() \
    ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>> {}

#define MAKE_SC_WITH(start_, step_) \
    ck_tile::static_counter<ck_tile::impl::static_counter_uniq_<__COUNTER__>, start_, step_> {}

#define NEXT_SC(c_) c_.next<__COUNTER__>()
#define NEXT_SCI(c_, static_i_) c_.next<__COUNTER__ + static_i_>()

// Usage:
// constexpr auto c = MAKE_SC()
// NEXT_SC(c)   // -> constexpr 0
// NEXT_SC(c)   // -> constexpr 1
// NEXT_SC(c)   // -> constexpr 2

} // namespace ck_tile
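Expanding slightly on the usage comment at the end of the file: each NEXT_SC expansion allocates a new slot through friend injection, so consecutive uses yield increasing compile-time constants. A hedged sketch (not from the commit; assumes the compiler accepts the stateful friend-injection technique this header relies on, as clang does with the warning suppressed):

// Sketch only: exercising static_counter at compile time.
#include "ck_tile/core/utility/static_counter.hpp"

constexpr auto c = MAKE_SC();

// default Start = 0, Step = 1
static_assert(NEXT_SC(c) == 0, "first next() should be 0");
static_assert(NEXT_SC(c) == 1, "second next() should be 1");
static_assert(NEXT_SC(c) == 2, "third next() should be 2");

// a counter with an explicit start and step
constexpr auto c2 = MAKE_SC_WITH(10, 2);
static_assert(NEXT_SC(c2) == 10, "starts at 10");
static_assert(NEXT_SC(c2) == 12, "advances by 2");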
include/ck_tile/host.hpp  (view file @ ef5e60f6)

...
@@ -11,6 +11,7 @@
 #include "ck_tile/host/fill.hpp"
 #include "ck_tile/host/hip_check_error.hpp"
 #include "ck_tile/host/host_tensor.hpp"
+#include "ck_tile/host/joinable_thread.hpp"
 #include "ck_tile/host/kernel_launch.hpp"
 #include "ck_tile/host/ranges.hpp"
 #include "ck_tile/host/reference/reference_batched_dropout.hpp"
...
@@ -20,6 +21,7 @@
 #include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
 #include "ck_tile/host/reference/reference_batched_softmax.hpp"
 #include "ck_tile/host/reference/reference_elementwise.hpp"
+#include "ck_tile/host/reference/reference_fused_moe.hpp"
 #include "ck_tile/host/reference/reference_gemm.hpp"
 #include "ck_tile/host/reference/reference_im2col.hpp"
 #include "ck_tile/host/reference/reference_layernorm2d_fwd.hpp"
...
...
include/ck_tile/host/device_memory.hpp  (view file @ ef5e60f6)

...
@@ -7,6 +7,7 @@
 #include <stdint.h>
 #include <stdexcept>
 #include "ck_tile/host/hip_check_error.hpp"
+#include "ck_tile/host/host_tensor.hpp"

 namespace ck_tile {

 template <typename T>
...
@@ -36,6 +37,19 @@ struct DeviceMem
            mpDeviceBuf = nullptr;
        }
    }

    template <typename T>
    DeviceMem(const HostTensor<T>& t) : mMemSize(t.get_element_space_size_in_bytes())
    {
        if(mMemSize != 0)
        {
            HIP_CHECK_ERROR(hipMalloc(static_cast<void**>(&mpDeviceBuf), mMemSize));
        }
        else
        {
            mpDeviceBuf = nullptr;
        }
        ToDevice(t.data());
    }

    void Realloc(std::size_t mem_size)
    {
        if(mpDeviceBuf)
...
@@ -92,6 +106,27 @@ struct DeviceMem
            HIP_CHECK_ERROR(hipMemcpy(p, mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
        }
    }

    // construct a host tensor with type T
    template <typename T>
    HostTensor<T> ToHost(std::size_t cpySize)
    {
        // TODO: host tensor could be slightly larger than the device tensor
        //       we just copy all data from GPU buffer
        std::size_t host_elements = (cpySize + sizeof(T) - 1) / sizeof(T);
        HostTensor<T> h_({host_elements});
        if(mpDeviceBuf)
        {
            HIP_CHECK_ERROR(hipMemcpy(h_.data(), mpDeviceBuf, cpySize, hipMemcpyDeviceToHost));
        }
        return h_;
    }

    template <typename T>
    HostTensor<T> ToHost()
    {
        return ToHost<T>(mMemSize);
    }

    void SetZero() const
    {
        if(mpDeviceBuf)
...
...
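The new HostTensor constructor and the ToHost helpers reduce the usual host/device round trip to one call in each direction. A hedged host-side sketch (not from the commit; the fill call follows fill.hpp and the kernel launch is elided):

// Sketch only: upload a filled host tensor, run something, download the result.
#include "ck_tile/host.hpp"

void example_round_trip()
{
    using namespace ck_tile;

    // fill a host tensor, then upload it in the DeviceMem constructor
    HostTensor<float> a_host({128, 64});
    FillUniformDistribution<float>{-1.f, 1.f}(a_host);
    DeviceMem a_dev(a_host); // allocates mMemSize bytes and calls ToDevice(a_host.data())

    // ... launch a kernel that consumes a_dev.GetDeviceBuffer() ...

    // download the whole buffer back as a flat host tensor of the same byte size
    HostTensor<float> a_back = a_dev.ToHost<float>();
}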
include/ck_tile/host/fill.hpp  (view file @ ef5e60f6)

...
@@ -13,6 +13,7 @@
 #include <unordered_set>

 #include "ck_tile/core.hpp"
+#include "ck_tile/host/joinable_thread.hpp"

 namespace ck_tile {
...
@@ -22,13 +23,44 @@ struct FillUniformDistribution
    float a_{-5.f};
    float b_{5.f};
    std::optional<uint32_t> seed_{11939};
+   // ATTENTION: threaded does not guarantee the distribution between thread
+   bool threaded = false;

    template <typename ForwardIter>
    void operator()(ForwardIter first, ForwardIter last) const
    {
-       std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
-       std::uniform_real_distribution<float> dis(a_, b_);
-       std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
+       if(threaded)
+       {
+           uint32_t num_thread  = std::thread::hardware_concurrency();
+           auto total           = static_cast<std::size_t>(std::distance(first, last));
+           auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
+
+           std::vector<joinable_thread> threads(num_thread);
+
+           for(std::size_t it = 0; it < num_thread; ++it)
+           {
+               std::size_t iw_begin = it * work_per_thread;
+               std::size_t iw_end   = std::min((it + 1) * work_per_thread, total);
+
+               auto thread_f = [this, total, iw_begin, iw_end, &first] {
+                   if(iw_begin > total || iw_end > total)
+                       return;
+                   // need to make each thread unique, add an offset to current seed
+                   std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
+                                                      : std::random_device{}());
+                   std::uniform_real_distribution<float> dis(a_, b_);
+                   std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
+                       return ck_tile::type_convert<T>(dis(gen));
+                   });
+               };
+
+               threads[it] = joinable_thread(thread_f);
+           }
+       }
+       else
+       {
+           std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
+           std::uniform_real_distribution<float> dis(a_, b_);
+           std::generate(first, last,
+                         [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
+       }
    }

    template <typename ForwardRange>
...
@@ -115,13 +147,44 @@ struct FillNormalDistribution
    float mean_{0.f};
    float variance_{1.f};
    std::optional<uint32_t> seed_{11939};
+   // ATTENTION: threaded does not guarantee the distribution between thread
+   bool threaded = false;

    template <typename ForwardIter>
    void operator()(ForwardIter first, ForwardIter last) const
    {
-       std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
-       std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
-       std::generate(first, last, [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
+       if(threaded)
+       {
+           uint32_t num_thread  = std::thread::hardware_concurrency();
+           auto total           = static_cast<std::size_t>(std::distance(first, last));
+           auto work_per_thread = static_cast<std::size_t>((total + num_thread - 1) / num_thread);
+
+           std::vector<joinable_thread> threads(num_thread);
+
+           for(std::size_t it = 0; it < num_thread; ++it)
+           {
+               std::size_t iw_begin = it * work_per_thread;
+               std::size_t iw_end   = std::min((it + 1) * work_per_thread, total);
+
+               auto thread_f = [this, total, iw_begin, iw_end, &first] {
+                   if(iw_begin > total || iw_end > total)
+                       return;
+                   // need to make each thread unique, add an offset to current seed
+                   std::mt19937 gen(seed_.has_value() ? (*seed_ + iw_begin)
+                                                      : std::random_device{}());
+                   std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
+                   std::generate(first + iw_begin, first + iw_end, [&dis, &gen]() {
+                       return ck_tile::type_convert<T>(dis(gen));
+                   });
+               };
+
+               threads[it] = joinable_thread(thread_f);
+           }
+       }
+       else
+       {
+           std::mt19937 gen(seed_.has_value() ? *seed_ : std::random_device{}());
+           std::normal_distribution<float> dis(mean_, std::sqrt(variance_));
+           std::generate(first, last,
+                         [&dis, &gen]() { return ck_tile::type_convert<T>(dis(gen)); });
+       }
    }

    template <typename ForwardRange>
...
@@ -235,6 +298,44 @@ struct FillMonotonicSeq
    }
 };

template <typename T, bool IsAscending = true>
struct FillStepRange
{
    float start_value_{0};
    float end_value_{3};
    float step_{1};

    template <typename ForwardIter>
    void operator()(ForwardIter first, ForwardIter last) const
    {
        std::generate(first, last, [=, n = start_value_]() mutable {
            auto tmp = n;
            n += step_;
            if constexpr(IsAscending)
            {
                if(n > end_value_)
                    n = start_value_;
            }
            else
            {
                if(n < end_value_)
                    n = start_value_;
            }
            return type_convert<T>(tmp);
        });
    }

    template <typename ForwardRange>
    auto operator()(ForwardRange&& range) const
        -> std::void_t<decltype(std::declval<const FillStepRange&>()(
            std::begin(std::forward<ForwardRange>(range)),
            std::end(std::forward<ForwardRange>(range))))>
    {
        (*this)(std::begin(std::forward<ForwardRange>(range)),
                std::end(std::forward<ForwardRange>(range)));
    }
};

template <typename T>
struct FillConstant
{
...
...
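A hedged sketch (not from the commit) of the two additions to fill.hpp: the `threaded` flag on the existing distribution fillers and the new FillStepRange functor.

// Sketch only: using the new threaded flag and FillStepRange.
#include "ck_tile/host/fill.hpp"
#include "ck_tile/host/host_tensor.hpp"

void example_fill()
{
    using namespace ck_tile;

    HostTensor<float> big({4096, 4096});

    // same default seed (11939), but the range is split across hardware threads;
    // the header's ATTENTION note warns the per-thread split changes the exact sequence
    FillUniformDistribution<float> fill_uniform{-5.f, 5.f, 11939u, /*threaded=*/true};
    fill_uniform(big);

    // 0, 1, 2, 3, 0, 1, 2, 3, ... wrapping when the next value passes end_value_
    HostTensor<int> idx({16});
    FillStepRange<int>{0.f, 3.f, 1.f}(idx);
}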
include/ck_tile/host/host_tensor.hpp  (view file @ ef5e60f6)

...
@@ -8,12 +8,13 @@
 #include <iostream>
 #include <iomanip>
 #include <numeric>
 #include <thread>
 #include <utility>
 #include <vector>
 #include <functional>
 #include <fstream>

 #include "ck_tile/core.hpp"
+#include "ck_tile/host/joinable_thread.hpp"
 #include "ck_tile/host/ranges.hpp"

 namespace ck_tile {
...
@@ -213,23 +214,6 @@ CK_TILE_HOST HostTensorDescriptor transpose_host_tensor_descriptor_given_new2old
    return HostTensorDescriptor(new_lengths, new_strides);
 }

-struct joinable_thread : std::thread
-{
-    template <typename... Xs>
-    joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
-    {
-    }
-
-    joinable_thread(joinable_thread&&) = default;
-    joinable_thread& operator=(joinable_thread&&) = default;
-
-    ~joinable_thread()
-    {
-        if(this->joinable())
-            this->join();
-    }
-};
-
 template <typename F, typename... Xs>
 struct ParallelTensorFunctor
 {
...
@@ -590,6 +574,107 @@ struct HostTensor
               size() * FromSize / ToSize};
    }

    friend std::ostream& operator<<(std::ostream& os, const HostTensor<T>& t)
    {
        os << t.mDesc;
        os << "[";
        for(typename Data::size_type idx = 0; idx < t.mData.size(); ++idx)
        {
            if(0 < idx)
            {
                os << ", ";
            }
            if constexpr(std::is_same_v<T, bf16_t> || std::is_same_v<T, fp16_t>)
            {
                os << type_convert<float>(t.mData[idx]) << " #### ";
            }
            else
            {
                os << t.mData[idx];
            }
        }
        os << "]";
        return os;
    }

    // read data from a file, as dtype
    // the file could dumped from torch as (targeting tensor is t here)
    //    numpy.savetxt("f.txt", t.view(-1).numpy())
    //    numpy.savetxt("f.txt", t.cpu().view(-1).numpy())            # from cuda to cpu to save
    //    numpy.savetxt("f.txt", t.cpu().view(-1).numpy(), fmt="%d")  # save as int
    // will output f.txt, each line is a value
    // dtype=float or int, internally will cast to real type
    void loadtxt(std::string file_name, std::string dtype = "float")
    {
        std::ifstream file(file_name);

        if(file.is_open())
        {
            std::string line;

            index_t cnt = 0;
            while(std::getline(file, line))
            {
                if(cnt >= static_cast<index_t>(mData.size()))
                {
                    throw std::runtime_error(std::string("data read from file:") + file_name +
                                             " is too big");
                }

                if(dtype == "float")
                {
                    mData[cnt] = type_convert<T>(std::stof(line));
                }
                else if(dtype == "int" || dtype == "int32")
                {
                    mData[cnt] = type_convert<T>(std::stoi(line));
                }
                cnt++;
            }
            file.close();

            if(cnt < static_cast<index_t>(mData.size()))
            {
                std::cerr << "Warning! reading from file:" << file_name
                          << ", does not match the size of this tensor" << std::endl;
            }
        }
        else
        {
            // Print an error message to the standard error
            // stream if the file cannot be opened.
            throw std::runtime_error(std::string("unable to open file:") + file_name);
        }
    }

    // can save to a txt file and read from torch as:
    //    torch.from_numpy(np.loadtxt('f.txt', dtype=np.int32/np.float32...)).view([...]).contiguous()
    void savetxt(std::string file_name, std::string dtype = "float")
    {
        std::ofstream file(file_name);

        if(file.is_open())
        {
            for(auto& itm : mData)
            {
                if(dtype == "float")
                    file << type_convert<float>(itm) << std::endl;
                else if(dtype == "int")
                    file << type_convert<int>(itm) << std::endl;
                else
                    // TODO: we didn't implement operator<< for all custom
                    //       data types, here fall back to float in case compile error
                    file << type_convert<float>(itm) << std::endl;
            }
            file.close();
        }
        else
        {
            // Print an error message to the standard error
            // stream if the file cannot be opened.
            throw std::runtime_error(std::string("unable to open file:") + file_name);
        }
    }

    Descriptor mDesc;
    Data mData;
};
...
...
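The new loadtxt/savetxt members follow the numpy/torch convention described in their comments: one value per line, flattened. A hedged round-trip sketch (not from the commit):

// Sketch only: dumping a tensor for inspection in torch and reading a
// torch/numpy-generated file back in.
#include "ck_tile/host/host_tensor.hpp"

void example_txt_io()
{
    ck_tile::HostTensor<float> t({2, 3});
    t.mData = {0.f, 1.f, 2.f, 3.f, 4.f, 5.f};

    // one value per line; load in python with
    //   torch.from_numpy(np.loadtxt('f.txt', dtype=np.float32)).view([2, 3]).contiguous()
    t.savetxt("f.txt");

    // read back a file produced by numpy.savetxt("f.txt", t.view(-1).numpy());
    // throws if the file holds more values than this tensor, warns if fewer
    ck_tile::HostTensor<float> u({2, 3});
    u.loadtxt("f.txt", "float");
}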
include/ck_tile/host/joinable_thread.hpp  (new file, mode 100644)  (view file @ ef5e60f6)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include <thread>
#include <utility>

namespace ck_tile {

struct joinable_thread : std::thread
{
    template <typename... Xs>
    joinable_thread(Xs&&... xs) : std::thread(std::forward<Xs>(xs)...)
    {
    }

    joinable_thread(joinable_thread&&) = default;
    joinable_thread& operator=(joinable_thread&&) = default;

    ~joinable_thread()
    {
        if(this->joinable())
            this->join();
    }
};

} // namespace ck_tile
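joinable_thread is the same RAII wrapper previously defined inside host_tensor.hpp, now in its own header so fill.hpp can use it too; its destructor joins automatically. A minimal sketch (not from the commit):

// Sketch only: a scope exit cannot leak a running std::thread.
#include "ck_tile/host/joinable_thread.hpp"
#include <atomic>
#include <cassert>

int main()
{
    std::atomic<int> counter{0};
    {
        ck_tile::joinable_thread t([&] { counter += 1; });
        // no explicit t.join(): ~joinable_thread() joins when the scope ends
    }
    assert(counter == 1);
    return 0;
}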
include/ck_tile/host/reference/reference_fused_moe.hpp  (new file, mode 100644)  (view file @ ef5e60f6)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

// [indexing implementation-1]
// using M_a as constexpr block_size to partition all tokens into different slices
// each slice map to one expert, and one expert can have multiple slices
// e.g. num_experts = 6, topk=3, M_a = 4, input_tokens = 5
// before sort, topk_ids is : [[0, 3, 5], [2, 3, 5], [1, 3, 5], [1, 2, 3], [1, 3, 5]]
//                             tok-0      tok-1      tok-2      tok-3      tok-4
//           topk_weight is : [[a, b, c], [d, e, f], [g, h, i], [j, k, l], [m, n, o]] (some float number)
//
// token_id_per_expert is : [[0], [2, 3, 4], [1, 3], [0, 1, 2, 3, 4], [], [0, 1, 2, 5]]
//  (only for reference)    exp-0  exp-1      exp-2   exp-3           exp-4 exp-5
// weight_id_per_expert is: [[a], [g, j, m], [d, k], [b, e, h, l, n], [], [c, f, i, o]]
//
// max_num_tokens_padded : topk * input_tokens + num_experts * (M_a - 1)
// max_num_tokens_padded : topk * input_tokens + num_experts * M_a - topk (updated)
// * this could be larger than actual, since actual tokens are on GPU
//
// sorted_token_ids_ptr  : [0, 6, 6, 6, 2, 3, 4, 6, 1, 3, 6, 6, 0, 1, 2, 3, 4, 6, 6, 6, 6, 6, 6, 6, 0, 1, 2, 5]
//                         |-  exp-0 -|-  exp-1 -|-  exp-2 -|-        exp-3        -|-  exp-4 -|-  exp-5 -|
// sorted_weight_ptr     : [a, *, *, *, g, j, m, *, d, k, *, *, b, e, h, l, n, *, *, *, *, *, *, *, c, f, i, o]
//
// * length is max_num_tokens_padded, actual size is num_tokens_post_padded_ptr
//
// sorted_expert_ids_ptr : [0, 1, 2, 3, 3, 4, 5]
// * length is (max_num_tokens_padded + block_size - 1) / block_size
//
// num_tokens_post_padded_ptr : [28]
// num_sorted_tiles_ptr : [7]
template <typename AccDataType, // you only need to explcitly set this one
          typename Activation,  // ck_tile::element_wise::Gelu
          typename ADataType,
          typename GDataType,
          typename DDataType,
          typename ODataType,
          typename AScaleDataType,
          typename GScaleDataType,
          typename DScaleDataType,
          typename YSmoothScaleDataType,
          typename TopkWeightDataType,
          typename IndexDataType>
void reference_fused_moe(
    const ck_tile::HostTensor<ADataType>& a_host,             // [tokens, hidden_size]
    const ck_tile::HostTensor<GDataType>& g_host,             // [experts, interme_size_0, hidden_size]
    const ck_tile::HostTensor<DDataType>& d_host,             // [experts, hidden_size, interme_size_1]
    const ck_tile::HostTensor<AScaleDataType>& sa_host,       // [tokens, 1],
    const ck_tile::HostTensor<GScaleDataType>& sg_host,       // [experts, 1, interme_size_0]
    const ck_tile::HostTensor<DScaleDataType>& sd_host,       // [experts, 1, hidden_size],
    const ck_tile::HostTensor<YSmoothScaleDataType>& sy_host, // [experts, 1, interme_size_0]
    ck_tile::HostTensor<ODataType>& o_host,                   // [tokens, hidden_size]
    const ck_tile::HostTensor<IndexDataType>& sorted_token_ids_host,   // [max_num_tokens_padded]
    const ck_tile::HostTensor<TopkWeightDataType>& sorted_weight_host, // [max_num_tokens_padded]
    const ck_tile::HostTensor<IndexDataType>& sorted_expert_ids_host,
    // [(max_num_tokens_padded + block_size - 1) / block_size]
    const ck_tile::HostTensor<IndexDataType>& num_sorted_tiles_host, // [1]
    const ck_tile::HostTensor<IndexDataType>& token_ids_host, // [tokens, topk] --> ugly!!! remove in the future
    ck_tile::index_t block_m,
    ck_tile::index_t tokens,
    ck_tile::index_t experts,
    ck_tile::index_t hidden_size,
    ck_tile::index_t intermediate_size, // this size is for gate/up
    ck_tile::index_t topk,
    ck_tile::index_t gate_only)
{
    assert(sorted_token_ids_host.get_num_of_dimension() == 1);
    assert(sorted_weight_host.get_num_of_dimension() == 1);
    assert(sorted_expert_ids_host.get_num_of_dimension() == 1);
    assert(num_sorted_tiles_host.get_element_size() == 1);

    ck_tile::index_t num_sorted_tiles    = num_sorted_tiles_host.mData[0] / block_m;
    ck_tile::index_t intermediate_size_0 = intermediate_size;
    ck_tile::index_t intermediate_size_1 = intermediate_size / (gate_only ? 1 : 2);

    // TODO: better remove this in the future, or modify the token_id value
    auto get_topk_id = [&](ck_tile::index_t token_id_, ck_tile::index_t expert_id_) {
        for(ck_tile::index_t i_ = 0; i_ < topk; i_++)
        {
            if(token_ids_host(token_id_, i_) == expert_id_)
                return i_;
        }
        throw std::runtime_error("not correct token/expert pair\n");
        return -1; // TODO: not correct!!
    };

    ck_tile::HostTensor<AccDataType> out_topk_tokens({tokens, topk, hidden_size});

    int max_num_tokens_padded = topk * tokens + experts * block_m - topk;
    // assert();
    auto f = [&](auto i_flatten) {
        ck_tile::index_t i_tile = i_flatten / block_m;
        if(i_tile >= num_sorted_tiles)
            return;
        ck_tile::index_t i_expert = sorted_expert_ids_host.mData[i_tile];
        ck_tile::index_t i_token  = sorted_token_ids_host.mData[i_flatten];
        if(i_token >= tokens)
            return;
        ck_tile::index_t i_topk = get_topk_id(i_token, i_expert); // TODO: ugly
        auto weight             = sorted_weight_host.mData[i_flatten];

        ck_tile::HostTensor<AccDataType> acc_0({1, intermediate_size_0});
        // first gemm
        for(ck_tile::index_t i_n = 0; i_n < intermediate_size_0; i_n++)
        {
            AccDataType acc = static_cast<AccDataType>(0);
            for(ck_tile::index_t i_k = 0; i_k < hidden_size; i_k++)
            {
                acc += type_convert<AccDataType>(a_host(i_token, i_k)) *
                       type_convert<AccDataType>(g_host(i_expert, i_n, i_k));
            }
            acc_0(0, i_n) = acc;
            // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, acc);
        }

        ck_tile::HostTensor<AccDataType> y({1, intermediate_size_1});
        if(gate_only)
        {
            if(intermediate_size_1 != intermediate_size_0)
                throw std::runtime_error("intermediate_size not correct, 0:" +
                                         std::to_string(intermediate_size_0) +
                                         ", 1:" + std::to_string(intermediate_size_1));
            for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
            {
                Activation{}(y(0, i_n), acc_0(0, i_n));
                // printf("ie:%2d, it:%3d, in:%d, %f\n", i_expert, i_token, i_n, y(0, i_n));
            }
        }
        else
        {
            if(intermediate_size_1 * 2 != intermediate_size_0)
                throw std::runtime_error("intermediate_size not correct, 0:" +
                                         std::to_string(intermediate_size_0) +
                                         ", 1:" + std::to_string(intermediate_size_1));
            for(ck_tile::index_t i_n = 0; i_n < intermediate_size_1; i_n++)
            {
                AccDataType tmp;
                Activation{}(tmp, acc_0(0, i_n));
                y(0, i_n) = tmp * acc_0(0, i_n + intermediate_size_1); // TODO: elementwise mul
            }
        }

        // second gemm, loop along gemm-n
        ck_tile::HostTensor<AccDataType> acc_1({1, hidden_size});
        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            AccDataType acc = static_cast<AccDataType>(0);
            for(ck_tile::index_t i_k = 0; i_k < intermediate_size_1; i_k++)
            {
                acc += y(0, i_k) * type_convert<AccDataType>(d_host(i_expert, i_n, i_k));
            }
            acc_1(0, i_n) = acc * weight; // multiple weight here
        }

        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            out_topk_tokens(i_token, i_topk, i_n) = acc_1(0, i_n);
        }
    };
    // make_ParallelTensorFunctor(f, max_num_tokens_padded)(std::thread::hardware_concurrency());
    make_ParallelTensorFunctor(f, max_num_tokens_padded)(1);

    // reduce
    auto r = [&](auto i_token) {
        for(ck_tile::index_t i_n = 0; i_n < hidden_size; i_n++)
        {
            AccDataType acc = type_convert<AccDataType>(0);
            for(ck_tile::index_t i_topk = 0; i_topk < topk; i_topk++)
            {
                acc += out_topk_tokens(i_token, i_topk, i_n);
            }
            o_host(i_token, i_n) = type_convert<ODataType>(acc);
        }
    };
    make_ParallelTensorFunctor(r, tokens)(std::thread::hardware_concurrency());

    (void)num_sorted_tiles_host;
    (void)sa_host;
    (void)sg_host;
    (void)sd_host;
    (void)sy_host;
}

} // namespace ck_tile
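The indexing comment at the top of this file works through a concrete case (num_experts = 6, topk = 3, M_a = 4, 5 input tokens). A small standalone sketch (not from the commit) evaluates the padded-size formula the comment and the function body use:

// Sketch only: the worst-case padded length used above,
//   max_num_tokens_padded = topk * tokens + experts * block_m - topk,
// evaluated for the example in the file-header comment.
#include <cstdio>

int main()
{
    const int tokens = 5, topk = 3, experts = 6, block_m = 4;

    const int max_num_tokens_padded = topk * tokens + experts * block_m - topk; // 36
    const int max_num_tiles = (max_num_tokens_padded + block_m - 1) / block_m;  // 9

    // the actual post-padding count (28 tokens / 7 tiles in the comment's example)
    // is only known after sorting, so these values are allocation upper bounds
    std::printf("max tokens padded: %d, max tiles: %d\n", max_num_tokens_padded, max_num_tiles);
    return 0;
}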
include/ck_tile/host/reference/reference_gemm.hpp  (view file @ ef5e60f6)

...
@@ -183,4 +183,116 @@ void reference_gemm_gpu(DeviceMem& a_device,
    return;
 }

template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC>
void reference_batched_gemm_gpu(DeviceMem& a_device,
                                DeviceMem& b_device,
                                DeviceMem& c_device,
                                index_t M,
                                index_t N,
                                index_t K,
                                index_t stride_a,
                                index_t stride_b,
                                index_t stride_c,
                                index_t batch_stride_A,
                                index_t batch_stride_B,
                                index_t batch_stride_C,
                                index_t batch_count)
{
    ADataType* d_A;
    BDataType* d_B;
    CDataType* d_C;

    hipError_t errA = hipMalloc(&d_A, batch_count * M * K * sizeof(ADataType));
    hipError_t errB = hipMalloc(&d_B, batch_count * N * K * sizeof(BDataType));
    hipError_t errC = hipMalloc(&d_C, batch_count * M * N * sizeof(CDataType));
    if(errA != hipSuccess)
    {
        std::cerr << "Error allocating device memory for A: " << hipGetErrorString(errA)
                  << std::endl;
        return; // Early exit on error
    }
    if(errB != hipSuccess)
    {
        std::cerr << "Error allocating device memory for B: " << hipGetErrorString(errB)
                  << std::endl;
        return; // Early exit on error
    }
    if(errC != hipSuccess)
    {
        std::cerr << "Error allocating device memory for C: " << hipGetErrorString(errC)
                  << std::endl;
        return; // Early exit on error
    }

    errA = hipMemcpy(d_A,
                     a_device.GetDeviceBuffer(),
                     batch_count * M * K * sizeof(ADataType),
                     hipMemcpyHostToDevice);
    if(errA != hipSuccess)
    {
        std::cerr << "Error copying A to device: " << hipGetErrorString(errA) << std::endl;
    }
    errB = hipMemcpy(d_B,
                     b_device.GetDeviceBuffer(),
                     batch_count * N * K * sizeof(BDataType),
                     hipMemcpyHostToDevice);
    if(errB != hipSuccess)
    {
        std::cerr << "Error copying B to device: " << hipGetErrorString(errB) << std::endl;
    }

    int totalElements      = M * N;
    int numThreadsPerBlock = 256; // Common choice for threads per block
    int numBlocks          = (totalElements + numThreadsPerBlock - 1) / numThreadsPerBlock;

    for(index_t batch_id = 0; batch_id < batch_count; ++batch_id)
    {
        ADataType* d_ATemp = d_A + batch_id * batch_stride_A;
        BDataType* d_BTemp = d_B + batch_id * batch_stride_B;
        CDataType* d_CTemp = d_C + batch_id * batch_stride_C;
        naive_gemm_kernel<ADataType, BDataType, AccDataType, CDataType, LayoutA, LayoutB, LayoutC>
            <<<numBlocks, numThreadsPerBlock>>>(
                d_ATemp, d_BTemp, d_CTemp, M, N, K, stride_a, stride_b, stride_c);
    }

    errC = hipMemcpy(c_device.GetDeviceBuffer(),
                     d_C,
                     batch_count * M * N * sizeof(CDataType),
                     hipMemcpyDeviceToHost);
    if(errC != hipSuccess)
    {
        std::cerr << "Error copying C to device: " << hipGetErrorString(errC) << std::endl;
    }

    errA = hipFree(d_A);
    if(errA != hipSuccess)
    {
        std::cerr << "Error free the A memory: " << hipGetErrorString(errA) << std::endl;
    }
    errB = hipFree(d_B);
    if(errB != hipSuccess)
    {
        std::cerr << "Error free the B memory: " << hipGetErrorString(errB) << std::endl;
    }
    errC = hipFree(d_C);
    if(errC != hipSuccess)
    {
        std::cerr << "Error free the C memory: " << hipGetErrorString(errC) << std::endl;
    }

    return;
}

} // namespace ck_tile
include/ck_tile/host/reference/reference_moe_sorting.hpp  (view file @ ef5e60f6)

...
@@ -8,6 +8,9 @@
 namespace ck_tile {

#define MOE_SORTING_MOCK_ID(token_id_, topk_id_) \
    static_cast<uint32_t>(((token_id_)&0x00ffffff) | (((topk_id_)&0xff) << 24))

 template <typename WeightType, typename IndexType = index_t>
 CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
                                         const HostTensor<WeightType>& weights,
...
@@ -20,8 +23,14 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
 {
    const index_t num_token = topk_ids.mDesc.get_lengths()[0];
    const index_t topk      = topk_ids.mDesc.get_lengths()[1];

-   std::vector<std::vector<IndexType>> expert_tokens(experts,
-                                                     std::vector<IndexType>(unit_size, num_token));
+   // allocate a temp buffer, and fill the value with [number_token|topk]
+   std::vector<std::vector<IndexType>> expert_tokens(experts,
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+       std::vector<IndexType>(unit_size, MOE_SORTING_MOCK_ID(num_token, topk)));
+#else
+       std::vector<IndexType>(unit_size, num_token));
+#endif
    std::vector<std::vector<WeightType>> expert_token_weights(
        experts, std::vector<WeightType>(unit_size, 0));
    std::vector<IndexType> expert_slices(experts, 1);
...
@@ -42,12 +51,19 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
            expert_token_weights[e].resize(new_size);
            for(index_t i = (expert_slices[e] - 1) * unit_size; i < new_size; i++)
            {
-               expert_tokens[e][i] = num_token;
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+               expert_tokens[e][i] = MOE_SORTING_MOCK_ID(num_token, topk);
+#else
+               expert_tokens[e][i] = num_token;
+#endif
                expert_token_weights[e][i] = 0;
            }
        }
-       expert_tokens[e][idx] = t;
+#if CK_TILE_REFERENCE_MOE_SORTING_MOCK_ID
+       expert_tokens[e][idx] = MOE_SORTING_MOCK_ID(t, k);
+#else
+       expert_tokens[e][idx] = t;
+#endif
        expert_token_weights[e][idx] = w;
        expert_slice_idxs[e]++;
    }
...
@@ -75,4 +91,7 @@ CK_TILE_HOST void reference_moe_sorting(const HostTensor<IndexType>& topk_ids,
        unit_cnt *= unit_size;

    return;
 }

#undef MOE_SORTING_MOCK_ID

} // namespace ck_tile
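The MOE_SORTING_MOCK_ID macro packs a topk index into the top byte of a 32-bit token id. A sketch (not from the commit; the unpack helpers are illustrative only) of the packing and the corresponding unpack:

// Sketch only: what MOE_SORTING_MOCK_ID encodes.
#include <cassert>
#include <cstdint>

// same packing as the macro: low 24 bits = token id, high 8 bits = topk id
static uint32_t mock_id(uint32_t token_id, uint32_t topk_id)
{
    return (token_id & 0x00ffffffu) | ((topk_id & 0xffu) << 24);
}

// illustrative unpack helpers (not part of the header)
static uint32_t mock_token(uint32_t id) { return id & 0x00ffffffu; }
static uint32_t mock_topk(uint32_t id) { return id >> 24; }

int main()
{
    const uint32_t id = mock_id(/*token_id=*/12345, /*topk_id=*/2);
    assert(mock_token(id) == 12345);
    assert(mock_topk(id) == 2);
    return 0;
}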
include/ck_tile/host/reference/reference_permute.hpp  (view file @ ef5e60f6)

...
@@ -16,7 +16,7 @@ namespace ck_tile {
 */
 template <typename DataType>
 CK_TILE_HOST void
-reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> dims)
+reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::vector<index_t> perm)
 {
    const auto x_len = x.mDesc.get_lengths();
    const auto y_len = y.mDesc.get_lengths();
...
@@ -43,7 +43,7 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
            std::vector<size_t> tmp(rank, 0);
            for(index_t i = 0; i < rank; i++)
            {
-               tmp[dims[i]] = y_coord[i];
+               tmp[perm[i]] = y_coord[i];
            }
            return tmp;
        }();
...
@@ -54,4 +54,23 @@ reference_permute(const HostTensor<DataType>& x, HostTensor<DataType>& y, std::v
    make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
 }

template <typename DataType>
CK_TILE_HOST auto reference_permute(const HostTensor<DataType>& x, std::vector<index_t> perm)
{
    auto x_shape          = x.get_lengths();
    ck_tile::index_t rank = perm.size();

    std::vector<ck_tile::index_t> y_shape = [&]() {
        std::vector<ck_tile::index_t> tmp(rank, 0);
        for(int i = 0; i < static_cast<int>(rank); i++)
        {
            tmp[i] = x_shape[perm[i]];
        }
        return tmp;
    }();
    HostTensor<DataType> y(y_shape);

    reference_permute(x, y, perm);
    return y;
}

} // namespace ck_tile
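The new overload derives the output shape from the permutation and returns the tensor. A hedged sketch (not from the commit) of both call styles:

// Sketch only: the two reference_permute entry points.
#include "ck_tile/host/host_tensor.hpp"
#include "ck_tile/host/reference/reference_permute.hpp"

void example_permute()
{
    using namespace ck_tile;

    HostTensor<float> x({2, 3, 4});

    // existing form: caller supplies a correctly shaped output tensor
    HostTensor<float> y({4, 2, 3});
    reference_permute(x, y, {2, 0, 1});

    // new form: output shape {x[2], x[0], x[1]} = {4, 2, 3} is derived and returned
    auto z = reference_permute(x, {2, 0, 1});
}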
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp  (view file @ ef5e60f6)

...
@@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
    static constexpr bool kSaveX             = Problem::kSaveX;
    static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
    static constexpr bool kPadM              = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
    static constexpr bool kPadN              = Problem::kPadN;
+   static constexpr bool UseMax3            = true; // TODO - Move to trait

    static constexpr const char* name = []() {
        if constexpr(kNeedCrossWarpSync)
...
@@ -69,9 +70,16 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
        auto reduce_square_sum_func = ReduceOp::SquareAdd{};
        auto reduce_sum_func        = ReduceOp::Add{};
        auto reduce_absmax_func     = ReduceOp::AbsMax{};
+       auto reduce_absmax3_func    = [](auto acc_, auto v_0_, auto v_1_) {
+           float rtn;
+           asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
+                        : "=v"(rtn)
+                        : "v"(acc_), "v"(v_0_), "v"(v_1_));
+           return rtn;
+       };
        auto reduce_max_func        = ReduceOp::Max{};
        auto block_reduce2d         = Policy::template GetBlockReduce2d<Problem>();
        auto block_reduce2d_sync    = Policy::template GetBlockReduce2dSync<Problem>();
        auto block_reduce2d_cross_warp_sync =
            Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
...
@@ -116,8 +124,23 @@ struct AddRmsnorm2dRdquantFwdPipelineOnePass
        });

        // compute absmax, each-thread->cross-lane->cross-warp
-       auto absmax = block_reduce2d(
-           y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
+       auto absmax = [&]() {
+           constexpr auto x_size_per_row =
+               x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
+           if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
+                        x_size_per_row % 2 == 0)
+           {
+               return block_reduce2d(y,
+                                     reduce_absmax_func.GetIdentityValue<ComputeDataType>(),
+                                     reduce_absmax3_func,
+                                     sequence<1, 2>{});
+           }
+           else
+           {
+               return block_reduce2d(
+                   y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
+           }
+       }();
        block_reduce2d_sync(absmax, reduce_max_func);
        block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);
...
...
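reduce_absmax3_func wraps the v_max3_f32 instruction with abs modifiers on the two new operands, so the absmax reduction folds two values per step instead of one. Its plain scalar equivalent, as a hedged sketch (not from the commit):

// Sketch only: scalar equivalent of the inline-asm lambda
//   "v_max3_f32 %0, %1, abs(%2), abs(%3)"
#include <algorithm>
#include <cmath>

static float reduce_absmax3(float acc, float v0, float v1)
{
    return std::max({acc, std::fabs(v0), std::fabs(v1)});
}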
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp  (view file @ ef5e60f6)

...
@@ -28,8 +28,9 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
    static constexpr bool kSaveX             = Problem::kSaveX;
    static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
    static constexpr bool kPadM              = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
    static constexpr bool kPadN              = Problem::kPadN;
+   static constexpr bool UseMax3            = true; // TODO - Move to trait

    static constexpr const char* name = []() {
        if constexpr(kNeedCrossWarpSync)
...
@@ -76,9 +77,16 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
        auto reduce_square_sum_func = ReduceOp::SquareAdd{};
        auto reduce_sum_func        = ReduceOp::Add{};
        auto reduce_absmax_func     = ReduceOp::AbsMax{};
+       auto reduce_absmax3_func    = [](auto acc_, auto v_0_, auto v_1_) {
+           float rtn;
+           asm volatile("v_max3_f32 %0, %1, abs(%2), abs(%3)"
+                        : "=v"(rtn)
+                        : "v"(acc_), "v"(v_0_), "v"(v_1_));
+           return rtn;
+       };
        auto reduce_max_func        = ReduceOp::Max{};
        auto block_reduce2d         = Policy::template GetBlockReduce2d<Problem>();
        auto block_reduce2d_sync    = Policy::template GetBlockReduce2dSync<Problem>();
        auto block_reduce2d_cross_warp_sync =
            Policy::template GetBlockReduce2dCrossWarpSync<Problem>();
...
@@ -177,7 +185,13 @@ struct AddRmsnorm2dRdquantFwdPipelineThreePass
                y(idx) = type_convert<ComputeDataType>(y_);
            });

-           block_reduce2d(y, absmax, reduce_absmax_func);
+           constexpr auto x_size_per_row =
+               x.get_tile_distribution().get_ys_to_d_descriptor().get_lengths().at(number<1>{});
+           if constexpr(UseMax3 && std::is_same_v<ComputeDataType, float> &&
+                        x_size_per_row % 2 == 0)
+               block_reduce2d(y, absmax, reduce_absmax3_func, sequence<1, 2>{});
+           else
+               block_reduce2d(y, absmax, reduce_absmax_func);

            if constexpr(kSaveX)
                move_tile_window(x_window, {0, -Block_N});
...
...
include/ck_tile/ops/elementwise/unary_element_wise_operation.hpp  (view file @ ef5e60f6)

...
@@ -572,6 +572,105 @@ struct FastGelu
    }
 };

struct FastGeluAsm
{
    template <typename Y, typename X>
    CK_TILE_HOST void operator()(Y& y, const X& x) const;

    template <typename Y, typename X>
    CK_TILE_DEVICE void operator()(Y& y, const X& x) const;

    template <>
    CK_TILE_HOST void operator()<float, float>(float& y, const float& x) const
    {
        // const float u = -2.f * x * (0.035677f * x * x + 0.797885f);
        const float c1  = -2.0 * 0.035677f;
        const float c2  = -2.0 * 0.797885f;
        const float u   = x * (c1 * x * x + c2);
        const float emu = exp(u);
        y               = x / (1.f + emu);
    }

    // device code, use lower precision "__ocml_exp_f32" and "rcp"
    template <>
    CK_TILE_DEVICE void operator()<float, float>(float& y, const float& x) const
    {
        const uint32_t c1     = 0xbd92220c; // -2.0 * 0.035677f;
        const float c2        = -2.0 * 0.797885f;
        const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
        float tmp;
        asm volatile(
            "v_mul_f32 %[v_tmp], %[v_x], %[v_x]             ; x*x\n"
            "v_fma_f32 %[v_tmp], %[v_tmp], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
            "v_mul_f32 %[v_tmp], %[v_tmp], %[v_x]           ; x*(c1*x*x+c2)\n"
            "v_mul_f32 %[v_tmp], %[v_tmp], %[s_log2e]       ; log2e*x*(c1*x*x+c2)\n"
            "v_exp_f32 %[v_tmp], %[v_tmp]                   ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
            "s_nop 0                                        ; hazard for exp\n"
            "v_add_f32 %[v_tmp], %[v_tmp], 1.0              ; emu+1.0f\n"
            "v_rcp_f32 %[v_tmp], %[v_tmp]                   ; 1/(emu+1.0f)\n"
            "s_nop 0                                        ; hazard for rcp\n"
            "v_mul_f32 %[v_y], %[v_tmp], %[v_x]             ; x * 1/(emu+1f)\n"
            : [v_y] "=v"(y), [v_tmp] "+v"(tmp)
            : [v_x] "v"(x), [s_c1] "s"(c1), [v_c2] "v"(c2), [s_log2e] "s"(log2e_)
            :);
    }

    template <>
    CK_TILE_HOST void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
    {
        const float c1 = -2.0 * 0.035677f;
        const float c2 = -2.0 * 0.797885f;

        const float u0   = x.x * (c1 * x.x * x.x + c2);
        const float emu0 = exp(u0);
        y.x              = x.x / (1.f + emu0);

        const float u1   = x.y * (c1 * x.y * x.y + c2);
        const float emu1 = exp(u1);
        y.y              = x.y / (1.f + emu1);
    }

    // this is packed verion to remove data hazard for trans
    template <>
    CK_TILE_DEVICE void operator()<fp32x2_t, fp32x2_t>(fp32x2_t& y, const fp32x2_t& x) const
    {
        const uint32_t c1     = 0xbd92220c; // -2.0 * 0.035677f;
        float c2              = -2.0 * 0.797885f;
        const uint32_t log2e_ = 0x3fb8aa3b; // log2e_v<float>;
        float tmp0, tmp1;
        float y0 = x.x, y1 = x.y;
        asm volatile(
            "v_mul_f32 %[v_tmp0], %[v_y0], %[v_y0]            ; x*x\n"
            "v_mul_f32 %[v_tmp1], %[v_y1], %[v_y1]            ; x*x\n"
            "v_fma_f32 %[v_tmp0], %[v_tmp0], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
            "v_fma_f32 %[v_tmp1], %[v_tmp1], %[s_c1], %[v_c2] ; c1*x*x+c2\n"
            "v_mul_f32 %[v_tmp0], %[v_tmp0], %[v_y0]          ; x*(c1*x*x+c2)\n"
            "v_mul_f32 %[v_tmp1], %[v_tmp1], %[v_y1]          ; x*(c1*x*x+c2)\n"
            "v_mul_f32 %[v_tmp0], %[v_tmp0], %[s_log2e]       ; log2e*x*(c1*x*x+c2)\n"
            "v_mul_f32 %[v_tmp1], %[v_tmp1], %[s_log2e]       ; log2e*x*(c1*x*x+c2)\n"
            "v_exp_f32 %[v_tmp0], %[v_tmp0]                   ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
            "v_exp_f32 %[v_tmp1], %[v_tmp1]                   ; emu = exp2(log2e*x*(c1*x*x+c2))\n"
            "v_add_f32 %[v_tmp0], %[v_tmp0], 1.0              ; emu+1.0f\n"
            "v_add_f32 %[v_tmp1], %[v_tmp1], 1.0              ; emu+1.0f\n"
            "v_rcp_f32 %[v_tmp0], %[v_tmp0]                   ; 1/(emu+1.0f)\n"
            "v_rcp_f32 %[v_tmp1], %[v_tmp1]                   ; 1/(emu+1.0f)\n"
            "v_mul_f32 %[v_y0], %[v_tmp0], %[v_y0]            ; x * 1/(emu+1f)\n"
            "v_mul_f32 %[v_y1], %[v_tmp1], %[v_y1]            ; x * 1/(emu+1f)\n"
            : [v_y0] "+v"(y0), [v_y1] "+v"(y1), [v_c2] "+v"(c2),
              // NOTE! it is totally possible that c2/y0/y1 share same register, they are all local
              // tmp variables we need to expicitly hint compiler they may read+write, to allow
              // allocate different register , the side effect is c2=** may issue for every such
              // inline asm block
              [v_tmp0] "+v"(tmp0), [v_tmp1] "+v"(tmp1)
            : [s_c1] "s"(c1), [s_log2e] "s"(log2e_)
            :);
        y.x = y0;
        y.y = y1;
    }
};

// https://paperswithcode.com/method/gelu
// y = 0.5*x*(1+erf(x/sqrt(2)))
struct Gelu
...
...
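Both the host and device paths of FastGeluAsm evaluate the same sigmoid form of the tanh-based fast-GELU approximation, y = x / (1 + exp(-2x(0.035677x^2 + 0.797885))). A hedged host-side sketch (not from the commit) showing that formula next to the usual tanh formulation, which is algebraically identical since 1/(1+exp(-2a)) = (1 + tanh(a))/2:

// Sketch only: the scalar formula FastGeluAsm implements, compared with the
// common tanh-based fast-GELU form.
#include <cmath>
#include <cstdio>

int main()
{
    for(float x : {-3.f, -1.f, -0.1f, 0.f, 0.5f, 2.f})
    {
        // FastGeluAsm host path: u = -2*x*(0.035677*x*x + 0.797885), y = x/(1+exp(u))
        const float u         = -2.f * x * (0.035677f * x * x + 0.797885f);
        const float y_sigmoid = x / (1.f + std::exp(u));

        // equivalent tanh form: 0.5*x*(1 + tanh(0.797885*x + 0.035677*x^3))
        const float y_tanh = 0.5f * x * (1.f + std::tanh(0.797885f * x + 0.035677f * x * x * x));

        std::printf("x=% .2f  sigmoid-form=% .6f  tanh-form=% .6f\n", x, y_sigmoid, y_tanh);
    }
    return 0;
}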
include/ck_tile/ops/moe_sorting.hpp → include/ck_tile/ops/flatmm.hpp  (view file @ ef5e60f6)

...
@@ -3,9 +3,8 @@
 #pragma once

-#include "ck_tile/ops/fused_moe/kernel/moe_sorting_kernel.hpp"
-#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_pipeline.hpp"
-#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_policy.hpp"
-#include "ck_tile/ops/fused_moe/pipeline/moe_sorting_problem.hpp"
+#include "ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp"
+#include "ck_tile/ops/flatmm/block/flatmm_sn_32x128x512_1x4x1_16x16x32.hpp"
+#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
 #include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/flatmm/block/flatmm_32x512x128_1x4x1_16x16x32.hpp
0 → 100644
View file @ ef5e60f6
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/flatmm/block/flatmm_uk_config.hpp"
namespace ck_tile {

// A async load to LDS, B direct to AGPR
// B matrix preshuffled in br*kr*w
// require 4 wave, occupancy=1c
// agpr usage:256
// vgpr usage:64(A local) + 64(acc) + 8(os_a) + 8(os_b) = 144 (rem:112)
//
// for this gemm, 4 16x16x16 transposed layout
// input A vgpr layout
//   v0-v15:  [ 0:15](gemm_m)x128(gemm_k)
//   v16-v31: [16:31](gemm_m)x128(gemm_k)
// input B vgpr layout
//   v0-v15:    [  0: 15](gemm_n)x128(gemm_k)
//   v16-v31:   [ 64: 79](gemm_n)x128(gemm_k)
//   ......................
//   v111-v127: [448:463](gemm_n)x128(gemm_k)
// output C vgpr layout
//   v0-v3 :  [ 0:15](gemm_m)x[  0: 15](gemm_n)
//   v4-v7 :  [16:31](gemm_m)x[  0: 15](gemm_n)
//   v8-v11:  [ 0:15](gemm_m)x[ 64: 79](gemm_n)
//   v12-v15: [16:31](gemm_m)x[ 64: 79](gemm_n)
//   ......................
//   v56-v59: [ 0:15](gemm_m)x[448:463](gemm_n)
//   v60-v63: [16:31](gemm_m)x[448:463](gemm_n)
struct Flatmm_32x512x128_1x4x1_16x16x32_Base // for f16/bf16
{
    static constexpr index_t Block_M = 32;
    static constexpr index_t Block_N = 512;
    static constexpr index_t Block_K = 128;

    static constexpr index_t WarpPerBlock_M = 1;
    static constexpr index_t WarpPerBlock_N = 4;
    static constexpr index_t WarpPerBlock_K = 1;

    static constexpr index_t NumWarps = 4;

    static constexpr index_t Warp_M = 16;
    static constexpr index_t Warp_N = 16;
    static constexpr index_t Warp_K = 32; // 16 * SubKPacks

    static constexpr index_t BlockSize = 256;

    static constexpr index_t SubKPacks = 2; // this is used to guarantee every thread can do dwordx4

    // TODO: note Nr/Kr/W need consider SubKPacks
    static constexpr index_t Block_W  = Warp_N * Warp_K;  // 512 element
    static constexpr index_t Block_Nr = Block_N / Warp_N; // 32 element, 4 per wave
    static constexpr index_t Block_Kr = Block_K / Warp_K; // 4

    static constexpr index_t Repeat_M = Block_M / (Warp_M * WarpPerBlock_M); // 2
    static constexpr index_t Repeat_N = Block_N / (Warp_N * WarpPerBlock_N); // 8
    static constexpr index_t Repeat_K = Block_K / (Warp_K * WarpPerBlock_K); // 8/2=4
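    // [illustrative check, not part of the original source] with the fixed shapes above, the
    // derived constants evaluate to exactly the values quoted in the trailing comments:
    static_assert(Block_W == 512 && Block_Nr == 32 && Block_Kr == 4, "flat B tile shape");
    static_assert(Repeat_M == 2 && Repeat_N == 8 && Repeat_K == 4, "per-wave repeats");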
    static CK_TILE_DEVICE constexpr auto MakeCBlockDist()
    {
        constexpr auto c_block_outer_dstr_encoding =
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<Repeat_M, WarpPerBlock_M>,
                                             sequence<Repeat_N, WarpPerBlock_N>>,
                                       tuple<sequence<1, 2>>,
                                       tuple<sequence<1, 1>>,
                                       sequence<2, 1>, // !! note here is different
                                       sequence<0, 0>>{};

        using WG = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution;

        constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
            c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
        constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
        return c_block_dstr;
    }
    static CK_TILE_DEVICE constexpr auto MakeCBlockTile()
    {
        using CDataType = float;
        constexpr auto c_block_dstr = MakeCBlockDist();
        auto c_block_tensor          = make_static_distributed_tensor<CDataType>(c_block_dstr);
        return c_block_tensor;
    }
    CK_TILE_HOST_DEVICE static constexpr auto MakeLdsStoreDesc_A()
    {
        // A async->LDS
        // constexpr index_t Block_M   = Problem::BlockShape::Block_M0;
        // constexpr index_t Block_K   = Problem::BlockShape::Block_K0;
        // constexpr index_t BlockSize = Problem::BlockShape::BlockSize;
        constexpr index_t warpSize = ck_tile::get_warp_size();
        // constexpr index_t NumWarps = Problem::BlockShape::NumWarps;

        constexpr index_t KPack_  = 8;      // GetSmemKPack_A<Problem>(); // LDS
        constexpr index_t KVector = 2;      // GetAlignment_A<Problem>(); // async copy 1 dword
        constexpr index_t KPad    = KPack_; // pad between warps

        static_assert(Block_K % KVector == 0);
        constexpr index_t LanesPerK = Block_K / KVector; // how many thread loading K
        if constexpr(LanesPerK >= warpSize)
        {
            // need multiple waves to load K
            static_assert(LanesPerK % warpSize == 0);
            constexpr index_t wavesPerK = LanesPerK / warpSize;
            if constexpr(wavesPerK > NumWarps)
            {
                // TODO: need multiple issues along K to load all data
            }
            else
            {
                constexpr index_t wavesPerM = NumWarps / wavesPerK;
                constexpr index_t NumIssues = Block_M / wavesPerM;

                constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                    make_tuple(number<NumIssues>{},  // m0
                               number<wavesPerM>{},  // m1
                               number<wavesPerK>{},  // k0
                               number<warpSize>{},   // k1
                               number<KVector>{}),   // k2
                    make_tuple(number<NumWarps * (warpSize * KVector + KPad)>{},  // m0
                               number<wavesPerK * (warpSize * KVector + KPad)>{}, // m1
                               number<warpSize * KVector + KPad>{},               // k0
                               number<KVector>{},                                 // k1
                               number<1>{}),                                      // k2
                    number<KVector>{}, // lds store vector(actually no explicit store)
                    number<1>{});

                constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
                    lds_block_desc_0,
                    make_tuple(make_pass_through_transform(number<NumIssues>{}),
                               make_merge_transform(
                                   make_tuple(number<wavesPerM>{}, number<wavesPerK>{})),
                               make_merge_transform(
                                   make_tuple(number<warpSize>{}, number<KVector>{}))),
                    make_tuple(sequence<0>{}, sequence<1, 2>{}, sequence<3, 4>{}),
                    make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

                return lds_block_desc_issues_warps_lanes;
            }
        }
        else
        {
            // lanes within a wave load different M but same K
            static_assert(warpSize % LanesPerK == 0);
            constexpr index_t LaneGroups = warpSize / LanesPerK; // along m
            constexpr index_t NumIssues  = Block_M / (LaneGroups * NumWarps);

            constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
                make_tuple(number<NumIssues>{},   // m0
                           number<LaneGroups>{}, // m1
                           number<NumWarps>{},   // m2
                           number<LanesPerK>{},  // k0
                           number<KVector>{}),   // k1
                make_tuple(number<NumWarps * (warpSize * KVector + KPad)>{}, // m0
                           number<Block_K>{},                                // m1
                           number<warpSize * KVector + KPad>{},              // m2
                           number<KVector>{},                                // k0
                           number<1>{}),                                     // k1
                number<KVector>{}, // lds store vector(actually no explicit store)
                number<1>{});

            constexpr auto lds_block_desc_issues_warps_lanes = transform_tensor_descriptor(
                lds_block_desc_0,
                make_tuple(make_pass_through_transform(number<NumIssues>{}),
                           make_pass_through_transform(number<NumWarps>{}),
                           make_merge_transform(make_tuple(
                               number<LaneGroups>{}, number<LanesPerK>{}, number<KVector>{}))),
                make_tuple(sequence<0>{}, sequence<2>{}, sequence<1, 3, 4>{}),
                make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));

            return lds_block_desc_issues_warps_lanes;
        }
    }
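    // [illustrative note, not part of the original source] assuming the gfx9 wave size of 64,
    // the fixed shape of this kernel (Block_M=32, Block_K=128, KVector=2, NumWarps=4) takes the
    // first branch above: LanesPerK = 128/2 = 64, wavesPerK = 1, wavesPerM = 4, and
    // NumIssues = 32/4 = 8, i.e. each lane performs 8 single-dword (2 x 16-bit element) async
    // copies per K-tile.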
    // template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeLdsLoadDesc_A()
    {
        // load from LDS to register, every wave has same layout
        constexpr index_t KPack_ = 8;      // GetSmemKPack_A<Problem>(); // LDS
        constexpr index_t KPad   = KPack_; // pad between warps

        constexpr index_t kAMLane     = 16;
        constexpr index_t kABKLane    = 4;
        constexpr index_t kABKPerLane = 4;
        constexpr index_t kKIter      = 2;
        static_assert(KPack_ == (kABKPerLane * kKIter));

        constexpr auto lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<Repeat_M>{}, // m0 y
                       number<kAMLane>{},  // m1 p
                       number<Repeat_K>{}, // k0 y
                       number<kABKLane>{}, // k1 p
                       number<KPack_>{}),  // k2 y-vector
            make_tuple(number<kAMLane * (Block_K + KPad)>{}, // m0
                       number<Block_K + KPad>{},             // m1
                       number<kABKLane * KPack_>{},          // k0
                       number<KPack_>{},                     // k1
                       number<1>{}),                         // k2
            number<KPack_>{}, // lds load vector
            number<1>{});

        constexpr auto lds_desc_m_k = transform_tensor_descriptor(
            lds_block_desc_0,
            make_tuple(make_merge_transform(make_tuple(number<Repeat_M>{}, number<kAMLane>{})),
                       make_merge_transform(make_tuple(
                           number<Repeat_K>{}, number<kABKLane>{}, number<KPack_>{}))),
            make_tuple(sequence<0, 1>{}, sequence<2, 3, 4>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return lds_desc_m_k;
    }
    static constexpr auto GetGemm_AWarpEnc()
    {
        constexpr index_t kAMLane     = 16;
        constexpr index_t kABKLane    = 4;
        constexpr index_t kABKPerLane = 4;
        constexpr index_t kKIter      = 2;

        using enc_ = tile_distribution_encoding<sequence<>,
                                                tuple<sequence<kAMLane>,
                                                      sequence<kABKLane, kABKPerLane * kKIter>>,
                                                tuple<sequence<2, 1>>,
                                                tuple<sequence<0, 0>>,
                                                sequence<2>,
                                                sequence<1>>;
        return enc_{};
    }
    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
        return 32 * (128 + 8) * sizeof(bf16_t);
    }
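    // [illustrative note, not part of the original source] this matches the A-tile LDS footprint
    // described by MakeLdsLoadDesc_A: Block_M(32) rows x (Block_K + KPad)(128 + 8) elements
    // x 2 bytes per bf16/fp16 element = 8704 bytes.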
};
struct Flatmm_32x512x128_1x4x1_16x16x32_BF16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base
{
    using ADataType = bf16_t;
    using BDataType = bf16_t;

    // TODO: need paired with tile_window_linear!
    // TODO: need call init_raw() before call this function!
    template <typename ARes, typename ACoords, typename BRes, typename BCoords>
    CK_TILE_DEVICE auto operator()(const ARes& res_a,
                                   const ACoords& cached_coords_a,
                                   const BRes& res_b,
                                   const BCoords& cached_coords_b,
                                   CK_TILE_LDS_ADDR void* smem,
                                   index_t k,
                                   index_t tile_offset_a, // for each tile, the offset to move for each unroll
                                   index_t tile_offset_b) // for each tile, the offset to move for each unroll
    {
        static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
        static_assert(BCoords::size() == Repeat_N);

        auto a_sst = make_tile_window(
            make_tensor_view<address_space_enum::lds>(
                reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
            MakeLdsStoreDesc_A().get_lengths(),
            {0, 0, 0});

        auto a_sld = [&]() {
            constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
            constexpr auto a_outer_dstr_enc =
                tile_distribution_encoding<sequence<WarpPerBlock_N>,
                                           tuple<sequence<Repeat_M, WarpPerBlock_M>,
                                                 sequence<Repeat_K>>,
                                           tuple<sequence<1, 0>>,
                                           tuple<sequence<1, 0>>,
                                           sequence<1, 2>,
                                           sequence<0, 0>>{};
            constexpr auto a_block_dstr_encode =
                detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
            return make_tile_window_linear(
                make_tensor_view<address_space_enum::lds>(
                    reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
                MakeLdsLoadDesc_A().get_lengths(),
                {0, 0},
                make_static_tile_distribution(a_block_dstr_encode));
        }();

        const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
        const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);

        const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
        constexpr auto smem_buf_size =
            MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);

        static_assert(a_sld.get_num_of_access() == 8);
        constexpr auto sld_os = generate_tuple(
            [&](auto i_access) {
                return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
            },
            number<a_sld.get_num_of_access()>{});

        index_t loop_cnt = k / Block_K;

        // this is the acc thread buffer
        fp32x4_t v_acc[16]{.0f};
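        // [illustrative note, not part of the original source] 16 x fp32x4 = 64 accumulator
        // VGPRs per lane, matching the "64(acc)" figure in the header comment.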
        // B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
        // clang-format off
        asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_BF16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
            : [s_loop_cnt] "+s"(loop_cnt),
              [v_acc_0] "+v"(v_acc[0]),   [v_acc_1] "+v"(v_acc[1]),
              [v_acc_2] "+v"(v_acc[2]),   [v_acc_3] "+v"(v_acc[3]),
              [v_acc_4] "+v"(v_acc[4]),   [v_acc_5] "+v"(v_acc[5]),
              [v_acc_6] "+v"(v_acc[6]),   [v_acc_7] "+v"(v_acc[7]),
              [v_acc_8] "+v"(v_acc[8]),   [v_acc_9] "+v"(v_acc[9]),
              [v_acc_10] "+v"(v_acc[10]), [v_acc_11] "+v"(v_acc[11]),
              [v_acc_12] "+v"(v_acc[12]), [v_acc_13] "+v"(v_acc[13]),
              [v_acc_14] "+v"(v_acc[14]), [v_acc_15] "+v"(v_acc[15]),
              [s_mem_] "+r"(smem)
            : [s_res_a0] "s"(res_a[0]), [s_res_a1] "s"(res_a[1]),
              [s_res_a2] "s"(res_a[2]), [s_res_a3] "s"(res_a[3]),
              [s_res_b0] "s"(res_b[0]), [s_res_b1] "s"(res_b[1]),
              [s_res_b2] "s"(res_b[2]), [s_res_b3] "s"(res_b[3]),
              [v_os_a0] "v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
              [v_os_a1] "v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
              [v_os_a2] "v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
              [v_os_a3] "v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
              [v_os_a4] "v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
              [v_os_a5] "v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
              [v_os_a6] "v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
              [v_os_a7] "v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),
              [v_os_b0] "v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
              [v_os_b1] "v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
              [v_os_b2] "v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
              [v_os_b3] "v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
              [v_os_b4] "v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
              [v_os_b5] "v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
              [v_os_b6] "v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
              [v_os_b7] "v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
              [v_os_slda] "v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
              [s_m0_init] "s"(m0_init_value),
              [s_size_per_issue] "s"(size_per_issue),
              [smem_sz] "n"(smem_buf_size), //(smem_buf_size),
              [sld_os_0] "n"(sld_os[number<0>{}].value), [sld_os_1] "n"(sld_os[number<1>{}].value),
              [sld_os_2] "n"(sld_os[number<2>{}].value), [sld_os_3] "n"(sld_os[number<3>{}].value),
              [sld_os_4] "n"(sld_os[number<4>{}].value), [sld_os_5] "n"(sld_os[number<5>{}].value),
              [sld_os_6] "n"(sld_os[number<6>{}].value), [sld_os_7] "n"(sld_os[number<7>{}].value),
              [s_tile_os_a] "s"(tile_offset_a_bytes),
              [s_tile_os_b] "s"(tile_offset_b_bytes)
            : "memory",
              "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "a10", "a11", "a12", "a13", "a14", "a15",
              "a16", "a17", "a18", "a19", "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", "a30", "a31",
              "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47",
              "a48", "a49", "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", "a60", "a61", "a62", "a63",
              "a64", "a65", "a66", "a67", "a68", "a69", "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
              "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", "a90", "a91", "a92", "a93", "a94", "a95",
              "a96", "a97", "a98", "a99", "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", "a108", "a109", "a110", "a111",
              "a112", "a113", "a114", "a115", "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", "a124", "a125", "a126", "a127",
              "a128", "a129", "a130", "a131", "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", "a140", "a141", "a142", "a143",
              "a144", "a145", "a146", "a147", "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", "a156", "a157", "a158", "a159",
              "a160", "a161", "a162", "a163", "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", "a172", "a173", "a174", "a175",
              "a176", "a177", "a178", "a179", "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", "a188", "a189", "a190", "a191",
              "a192", "a193", "a194", "a195", "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", "a204", "a205", "a206", "a207",
              "a208", "a209", "a210", "a211", "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", "a220", "a221", "a222", "a223",
              "a224", "a225", "a226", "a227", "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", "a236", "a237", "a238", "a239",
              "a240", "a241", "a242", "a243", "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", "a252", "a253", "a254", "a255",
              "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", "s86", // s86 as tmp
              "v64", "v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
              "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", "v90", "v91", "v92", "v93", "v94", "v95",
              "v96", "v97", "v98", "v99", "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", "v108", "v109", "v110", "v111",
              "v112", "v113", "v114", "v115", "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", "v124", "v125", "v126", "v127");
        // clang-format on
#pragma clang diagnostic pop

        // return local scratch
        auto c = MakeCBlockTile();
        for(auto i = 0; i < 16; i++)
        {
            c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
            c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
            c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
            c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
        }
        return c;
    }
};
struct Flatmm_32x512x128_1x4x1_16x16x32_FP16 : public Flatmm_32x512x128_1x4x1_16x16x32_Base
{
    using ADataType = fp16_t;
    using BDataType = fp16_t;

    // TODO: need paired with tile_window_linear!
    // TODO: need call init_raw() before call this function!
    template <typename ARes, typename ACoords, typename BRes, typename BCoords>
    CK_TILE_DEVICE auto operator()(const ARes& res_a,
                                   const ACoords& cached_coords_a,
                                   const BRes& res_b,
                                   const BCoords& cached_coords_b,
                                   CK_TILE_LDS_ADDR void* smem,
                                   index_t k,
                                   index_t tile_offset_a, // for each tile, the offset to move for each unroll
                                   index_t tile_offset_b) // for each tile, the offset to move for each unroll
    {
        static_assert(ACoords::size() == Block_M * Block_K / BlockSize / 2 /*2x per dword*/); // 8
        static_assert(BCoords::size() == Repeat_N);

        auto a_sst = make_tile_window(
            make_tensor_view<address_space_enum::lds>(
                reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsStoreDesc_A()),
            MakeLdsStoreDesc_A().get_lengths(),
            {0, 0, 0});

        auto a_sld = [&]() {
            constexpr auto a_warp_enc_ = GetGemm_AWarpEnc();
            constexpr auto a_outer_dstr_enc =
                tile_distribution_encoding<sequence<WarpPerBlock_N>,
                                           tuple<sequence<Repeat_M, WarpPerBlock_M>,
                                                 sequence<Repeat_K>>,
                                           tuple<sequence<1, 0>>,
                                           tuple<sequence<1, 0>>,
                                           sequence<1, 2>,
                                           sequence<0, 0>>{};
            constexpr auto a_block_dstr_encode =
                detail::make_embed_tile_distribution_encoding(a_outer_dstr_enc, a_warp_enc_);
            return make_tile_window_linear(
                make_tensor_view<address_space_enum::lds>(
                    reinterpret_cast<CK_TILE_LDS_ADDR ADataType*>(smem), MakeLdsLoadDesc_A()),
                MakeLdsLoadDesc_A().get_lengths(),
                {0, 0},
                make_static_tile_distribution(a_block_dstr_encode));
        }();

        const index_t tile_offset_a_bytes = tile_offset_a * sizeof(ADataType);
        const index_t tile_offset_b_bytes = tile_offset_b * sizeof(BDataType);

        const auto [m0_init_value, size_per_issue] = get_async_store_smem_info(a_sst);
        constexpr auto smem_buf_size =
            MakeLdsLoadDesc_A().get_element_space_size() * sizeof(ADataType);

        static_assert(a_sld.get_num_of_access() == 8);
        constexpr auto sld_os = generate_tuple(
            [&](auto i_access) {
                return number<a_sld.get_bottom_linear_offset(i_access) * sizeof(ADataType)>{};
            },
            number<a_sld.get_num_of_access()>{});

        index_t loop_cnt = k / Block_K;

        // this is the acc thread buffer
        fp32x4_t v_acc[16]{.0f};

        // B nr->kr
#pragma clang diagnostic push
#pragma clang diagnostic ignored "-Winline-asm"
        // clang-format off
        asm volatile(
#define CK_TILE_FLATMM_UK_MFMA CK_TILE_FLATMM_UK_MFMA_FP16
#include "uk/flatmm_uk_gfx9_32x512x128_1x1x1_16x16x16.inc"
#undef CK_TILE_FLATMM_UK_MFMA
            : [s_loop_cnt] "+s"(loop_cnt),
              [v_acc_0] "+v"(v_acc[0]),   [v_acc_1] "+v"(v_acc[1]),
              [v_acc_2] "+v"(v_acc[2]),   [v_acc_3] "+v"(v_acc[3]),
              [v_acc_4] "+v"(v_acc[4]),   [v_acc_5] "+v"(v_acc[5]),
              [v_acc_6] "+v"(v_acc[6]),   [v_acc_7] "+v"(v_acc[7]),
              [v_acc_8] "+v"(v_acc[8]),   [v_acc_9] "+v"(v_acc[9]),
              [v_acc_10] "+v"(v_acc[10]), [v_acc_11] "+v"(v_acc[11]),
              [v_acc_12] "+v"(v_acc[12]), [v_acc_13] "+v"(v_acc[13]),
              [v_acc_14] "+v"(v_acc[14]), [v_acc_15] "+v"(v_acc[15]),
              [s_mem_] "+r"(smem)
            : [s_res_a0] "s"(res_a[0]), [s_res_a1] "s"(res_a[1]),
              [s_res_a2] "s"(res_a[2]), [s_res_a3] "s"(res_a[3]),
              [s_res_b0] "s"(res_b[0]), [s_res_b1] "s"(res_b[1]),
              [s_res_b2] "s"(res_b[2]), [s_res_b3] "s"(res_b[3]),
              [v_os_a0] "v"(static_cast<index_t>(cached_coords_a[number<0>{}] * sizeof(ADataType))),
              [v_os_a1] "v"(static_cast<index_t>(cached_coords_a[number<1>{}] * sizeof(ADataType))),
              [v_os_a2] "v"(static_cast<index_t>(cached_coords_a[number<2>{}] * sizeof(ADataType))),
              [v_os_a3] "v"(static_cast<index_t>(cached_coords_a[number<3>{}] * sizeof(ADataType))),
              [v_os_a4] "v"(static_cast<index_t>(cached_coords_a[number<4>{}] * sizeof(ADataType))),
              [v_os_a5] "v"(static_cast<index_t>(cached_coords_a[number<5>{}] * sizeof(ADataType))),
              [v_os_a6] "v"(static_cast<index_t>(cached_coords_a[number<6>{}] * sizeof(ADataType))),
              [v_os_a7] "v"(static_cast<index_t>(cached_coords_a[number<7>{}] * sizeof(ADataType))),
              [v_os_b0] "v"(static_cast<index_t>(cached_coords_b[number<0>{}] * sizeof(BDataType))),
              [v_os_b1] "v"(static_cast<index_t>(cached_coords_b[number<1>{}] * sizeof(BDataType))),
              [v_os_b2] "v"(static_cast<index_t>(cached_coords_b[number<2>{}] * sizeof(BDataType))),
              [v_os_b3] "v"(static_cast<index_t>(cached_coords_b[number<3>{}] * sizeof(BDataType))),
              [v_os_b4] "v"(static_cast<index_t>(cached_coords_b[number<4>{}] * sizeof(BDataType))),
              [v_os_b5] "v"(static_cast<index_t>(cached_coords_b[number<5>{}] * sizeof(BDataType))),
              [v_os_b6] "v"(static_cast<index_t>(cached_coords_b[number<6>{}] * sizeof(BDataType))),
              [v_os_b7] "v"(static_cast<index_t>(cached_coords_b[number<7>{}] * sizeof(BDataType))),
              [v_os_slda] "v"(static_cast<index_t>(a_sld.cached_coords_[number<0>{}].get_offset() * sizeof(ADataType))),
              [s_m0_init] "s"(m0_init_value),
              [s_size_per_issue] "s"(size_per_issue),
              [smem_sz] "n"(smem_buf_size), //(smem_buf_size),
              [sld_os_0] "n"(sld_os[number<0>{}].value), [sld_os_1] "n"(sld_os[number<1>{}].value),
              [sld_os_2] "n"(sld_os[number<2>{}].value), [sld_os_3] "n"(sld_os[number<3>{}].value),
              [sld_os_4] "n"(sld_os[number<4>{}].value), [sld_os_5] "n"(sld_os[number<5>{}].value),
              [sld_os_6] "n"(sld_os[number<6>{}].value), [sld_os_7] "n"(sld_os[number<7>{}].value),
              [s_tile_os_a] "s"(tile_offset_a_bytes),
              [s_tile_os_b] "s"(tile_offset_b_bytes)
            : "memory",
              "a0", "a1", "a2", "a3", "a4", "a5", "a6", "a7", "a8", "a9", "a10", "a11", "a12", "a13", "a14", "a15",
              "a16", "a17", "a18", "a19", "a20", "a21", "a22", "a23", "a24", "a25", "a26", "a27", "a28", "a29", "a30", "a31",
              "a32", "a33", "a34", "a35", "a36", "a37", "a38", "a39", "a40", "a41", "a42", "a43", "a44", "a45", "a46", "a47",
              "a48", "a49", "a50", "a51", "a52", "a53", "a54", "a55", "a56", "a57", "a58", "a59", "a60", "a61", "a62", "a63",
              "a64", "a65", "a66", "a67", "a68", "a69", "a70", "a71", "a72", "a73", "a74", "a75", "a76", "a77", "a78", "a79",
              "a80", "a81", "a82", "a83", "a84", "a85", "a86", "a87", "a88", "a89", "a90", "a91", "a92", "a93", "a94", "a95",
              "a96", "a97", "a98", "a99", "a100", "a101", "a102", "a103", "a104", "a105", "a106", "a107", "a108", "a109", "a110", "a111",
              "a112", "a113", "a114", "a115", "a116", "a117", "a118", "a119", "a120", "a121", "a122", "a123", "a124", "a125", "a126", "a127",
              "a128", "a129", "a130", "a131", "a132", "a133", "a134", "a135", "a136", "a137", "a138", "a139", "a140", "a141", "a142", "a143",
              "a144", "a145", "a146", "a147", "a148", "a149", "a150", "a151", "a152", "a153", "a154", "a155", "a156", "a157", "a158", "a159",
              "a160", "a161", "a162", "a163", "a164", "a165", "a166", "a167", "a168", "a169", "a170", "a171", "a172", "a173", "a174", "a175",
              "a176", "a177", "a178", "a179", "a180", "a181", "a182", "a183", "a184", "a185", "a186", "a187", "a188", "a189", "a190", "a191",
              "a192", "a193", "a194", "a195", "a196", "a197", "a198", "a199", "a200", "a201", "a202", "a203", "a204", "a205", "a206", "a207",
              "a208", "a209", "a210", "a211", "a212", "a213", "a214", "a215", "a216", "a217", "a218", "a219", "a220", "a221", "a222", "a223",
              "a224", "a225", "a226", "a227", "a228", "a229", "a230", "a231", "a232", "a233", "a234", "a235", "a236", "a237", "a238", "a239",
              "a240", "a241", "a242", "a243", "a244", "a245", "a246", "a247", "a248", "a249", "a250", "a251", "a252", "a253", "a254", "a255",
              "s16", "s17", "s18", "s19", "s20", "s21", "s22", "s23", "s86", // s86 as tmp
              "v64", "v65", "v66", "v67", "v68", "v69", "v70", "v71", "v72", "v73", "v74", "v75", "v76", "v77", "v78", "v79",
              "v80", "v81", "v82", "v83", "v84", "v85", "v86", "v87", "v88", "v89", "v90", "v91", "v92", "v93", "v94", "v95",
              "v96", "v97", "v98", "v99", "v100", "v101", "v102", "v103", "v104", "v105", "v106", "v107", "v108", "v109", "v110", "v111",
              "v112", "v113", "v114", "v115", "v116", "v117", "v118", "v119", "v120", "v121", "v122", "v123", "v124", "v125", "v126", "v127");
        // clang-format on
#pragma clang diagnostic pop

        // return local scratch
        auto c = MakeCBlockTile();
        for(auto i = 0; i < 16; i++)
        {
            c.get_thread_buffer()[4 * i + 0] = v_acc[i].x;
            c.get_thread_buffer()[4 * i + 1] = v_acc[i].y;
            c.get_thread_buffer()[4 * i + 2] = v_acc[i].z;
            c.get_thread_buffer()[4 * i + 3] = v_acc[i].w;
        }
        return c;
    }
};

} // namespace ck_tile