Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
ead5167a
Commit
ead5167a
authored
Nov 01, 2024
by
dummycoderfe
Browse files
merge develop
parents
da1a2829
03c6448b
Changes
137
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
225 additions
and
113 deletions
+225
-113
include/ck_tile/core/numeric/math.hpp
include/ck_tile/core/numeric/math.hpp
+1
-1
include/ck_tile/core/numeric/type_convert.hpp
include/ck_tile/core/numeric/type_convert.hpp
+4
-0
include/ck_tile/core/tensor/null_tile_window.hpp
include/ck_tile/core/tensor/null_tile_window.hpp
+7
-0
include/ck_tile/host/reference/reference_elementwise.hpp
include/ck_tile/host/reference/reference_elementwise.hpp
+1
-1
include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp
include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp
+32
-5
include/ck_tile/host/reference/reference_permute.hpp
include/ck_tile/host/reference/reference_permute.hpp
+1
-1
include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp
include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp
+1
-1
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
+1
-1
include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp
...orm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp
+9
-8
include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp
...norm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp
+0
-78
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp
...ine/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp
+1
-0
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp
...t/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp
+1
-1
include/ck_tile/ops/common.hpp
include/ck_tile/ops/common.hpp
+1
-0
include/ck_tile/ops/common/generic_2d_block_shape.hpp
include/ck_tile/ops/common/generic_2d_block_shape.hpp
+3
-4
include/ck_tile/ops/elementwise.hpp
include/ck_tile/ops/elementwise.hpp
+1
-0
include/ck_tile/ops/epilogue.hpp
include/ck_tile/ops/epilogue.hpp
+2
-0
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
+17
-11
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp
+140
-0
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+1
-0
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
+1
-1
No files found.
include/ck_tile/core/numeric/math.hpp
View file @
ead5167a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/core/numeric/type_convert.hpp
View file @
ead5167a
...
@@ -10,6 +10,7 @@
...
@@ -10,6 +10,7 @@
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/half.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/bfloat16.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/float8.hpp"
#include "ck_tile/core/numeric/int8.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
...
@@ -60,6 +61,9 @@ CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
...
@@ -60,6 +61,9 @@ CK_TILE_TYPE_CONVERT(bf16_t, bf16, float, float)
CK_TILE_TYPE_CONVERT
(
fp8_t
,
fp8
,
float
,
float
)
CK_TILE_TYPE_CONVERT
(
fp8_t
,
fp8
,
float
,
float
)
CK_TILE_TYPE_CONVERT
(
bf8_t
,
bf8
,
float
,
float
)
CK_TILE_TYPE_CONVERT
(
bf8_t
,
bf8
,
float
,
float
)
CK_TILE_TYPE_CONVERT
(
float
,
float
,
int8_t
,
int8
)
CK_TILE_TYPE_CONVERT
(
int8_t
,
int8
,
float
,
float
)
#undef CK_TILE_TYPE_CONVERT
#undef CK_TILE_TYPE_CONVERT
#endif
#endif
...
...
include/ck_tile/core/tensor/null_tile_window.hpp
View file @
ead5167a
...
@@ -80,6 +80,13 @@ CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
...
@@ -80,6 +80,13 @@ CK_TILE_DEVICE constexpr auto make_tile_window(null_tensor_view,
return
null_tile_window
<
remove_cvref_t
<
WindowLengths
>>
{
window_lengths
};
return
null_tile_window
<
remove_cvref_t
<
WindowLengths
>>
{
window_lengths
};
}
}
template
<
typename
WindowLengths
,
typename
StaticTileDistribution
>
CK_TILE_DEVICE
constexpr
auto
make_tile_window
(
const
null_tile_window
<
WindowLengths
>&
t
,
const
StaticTileDistribution
&
)
{
return
t
;
}
template
<
typename
WindowLengths
>
template
<
typename
WindowLengths
>
CK_TILE_DEVICE
void
CK_TILE_DEVICE
void
move_tile_window
(
null_tile_window
<
WindowLengths
>&
,
move_tile_window
(
null_tile_window
<
WindowLengths
>&
,
...
...
include/ck_tile/host/reference/reference_elementwise.hpp
View file @
ead5167a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp
View file @
ead5167a
...
@@ -8,20 +8,44 @@
...
@@ -8,20 +8,44 @@
namespace
ck_tile
{
namespace
ck_tile
{
// Note: for simplicity, each functor only care about single M
struct
reference_layernorm2d_default_epilogue
{
template
<
typename
OutDataType
,
typename
AccDataType
>
void
operator
()(
int
m
,
HostTensor
<
OutDataType
>&
o
,
const
HostTensor
<
AccDataType
>&
acc
)
{
const
int
N
=
acc
.
mDesc
.
get_lengths
()[
1
];
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
o
(
m
,
n
)
=
ck_tile
::
type_convert
<
OutDataType
>
(
acc
(
m
,
n
));
}
}
template
<
typename
OutDataType
,
typename
AccDataType
>
auto
operator
()(
int
m
,
const
HostTensor
<
AccDataType
>&
acc
)
{
HostTensor
<
OutDataType
>
o
(
acc
.
get_lengths
(),
acc
.
get_strides
());
operator
()(
m
,
o
,
acc
);
return
o
;
}
};
template
<
typename
XDataType
,
template
<
typename
XDataType
,
typename
GammaDataType
,
typename
GammaDataType
,
typename
BetaDataType
,
typename
BetaDataType
,
typename
ComputeDataType
,
typename
ComputeDataType
,
typename
YDataType
,
typename
YDataType
,
typename
MeanDataType
,
typename
MeanDataType
,
typename
InvStdDataType
>
typename
InvStdDataType
,
typename
Epilogue
=
reference_layernorm2d_default_epilogue
>
void
reference_layernorm2d_fwd
(
const
HostTensor
<
XDataType
>&
x_m_n
,
void
reference_layernorm2d_fwd
(
const
HostTensor
<
XDataType
>&
x_m_n
,
const
HostTensor
<
GammaDataType
>&
gamma_n
,
const
HostTensor
<
GammaDataType
>&
gamma_n
,
const
HostTensor
<
BetaDataType
>&
beta_n
,
const
HostTensor
<
BetaDataType
>&
beta_n
,
HostTensor
<
YDataType
>&
y_m_n
,
HostTensor
<
YDataType
>&
y_m_n
,
HostTensor
<
MeanDataType
>&
mean_m
,
HostTensor
<
MeanDataType
>&
mean_m
,
HostTensor
<
InvStdDataType
>&
invStd_m
,
HostTensor
<
InvStdDataType
>&
invStd_m
,
ComputeDataType
epsilon
)
ComputeDataType
epsilon
,
Epilogue
epilogue_functor
=
{})
{
{
auto
layernorm2d_fwd_func
=
[
&
](
auto
m
)
{
auto
layernorm2d_fwd_func
=
[
&
](
auto
m
)
{
const
int
N
=
x_m_n
.
mDesc
.
get_lengths
()[
1
];
const
int
N
=
x_m_n
.
mDesc
.
get_lengths
()[
1
];
...
@@ -51,16 +75,19 @@ void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
...
@@ -51,16 +75,19 @@ void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
if
constexpr
(
!
std
::
is_same_v
<
InvStdDataType
,
ck_tile
::
null_type
>
)
if
constexpr
(
!
std
::
is_same_v
<
InvStdDataType
,
ck_tile
::
null_type
>
)
invStd_m
(
m
)
=
ck_tile
::
type_convert
<
InvStdDataType
>
(
divisor
);
invStd_m
(
m
)
=
ck_tile
::
type_convert
<
InvStdDataType
>
(
divisor
);
HostTensor
<
ComputeDataType
>
acc
(
x_m_n
.
get_lengths
(),
x_m_n
.
get_strides
());
for
(
int
n
=
0
;
n
<
N
;
++
n
)
for
(
int
n
=
0
;
n
<
N
;
++
n
)
{
{
ComputeDataType
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_m_n
(
m
,
n
));
ComputeDataType
x
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
x_m_n
(
m
,
n
));
ComputeDataType
gamma
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
gamma_n
(
n
));
ComputeDataType
gamma
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
gamma_n
(
n
));
ComputeDataType
beta
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
beta_n
(
n
));
ComputeDataType
beta
=
ck_tile
::
type_convert
<
ComputeDataType
>
(
beta_n
(
n
));
auto
y
=
(
x
-
mean
)
*
divisor
;
auto
a_
=
(
x
-
mean
)
*
divisor
;
y
=
y
*
gamma
+
beta
;
a_
=
a_
*
gamma
+
beta
;
y_m_n
(
m
,
n
)
=
ck_tile
::
type_convert
<
YDataType
>
(
y
)
;
acc
(
m
,
n
)
=
a_
;
}
}
epilogue_functor
(
m
,
y_m_n
,
acc
);
};
};
make_ParallelTensorFunctor
(
layernorm2d_fwd_func
,
make_ParallelTensorFunctor
(
layernorm2d_fwd_func
,
...
...
include/ck_tile/host/reference/reference_permute.hpp
View file @
ead5167a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp
View file @
ead5167a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp
View file @
ead5167a
...
@@ -4,9 +4,9 @@
...
@@ -4,9 +4,9 @@
#pragma once
#pragma once
#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp
View file @
ead5167a
...
@@ -9,15 +9,16 @@
...
@@ -9,15 +9,16 @@
namespace
ck_tile
{
namespace
ck_tile
{
// host side args
// host side args
// X = A + B, Y = Rmsnorm2d(X), QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
struct
AddRmsnorm2dRdquantFwdHostArgs
struct
AddRmsnorm2dRdquantFwdHostArgs
{
{
const
void
*
p_a
;
const
void
*
p_a
;
// [m ,n], input, fp16/bf16
const
void
*
p_b
;
const
void
*
p_b
;
// [m ,n], input, fp16/bf16
const
void
*
p_gamma
;
const
void
*
p_gamma
;
// [1, n], gamma, prec same as input
void
*
p_x
;
void
*
p_x
;
// [m, n], output, p_a + p_b, fp16/bf16
void
*
p_yscale
;
void
*
p_yscale
;
// [m, 1], output, rowwise quant scale (amax / 127) of reuslt of rmsnorm2d(x)
void
*
p_qy
;
void
*
p_qy
;
// [m, n], output, result of quant tensor of rmsnorm2d(x) int8
float
epsilon
;
float
epsilon
;
...
@@ -90,7 +91,7 @@ struct AddRmsnorm2dRdquantFwd
...
@@ -90,7 +91,7 @@ struct AddRmsnorm2dRdquantFwd
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
CK_TILE_HOST
static
constexpr
auto
GridSize
(
const
Hargs
&
hargs
)
{
{
return
integer_divide_ceil
(
hargs
.
m
,
Block_M
);
return
dim3
(
integer_divide_ceil
(
hargs
.
m
,
Block_M
)
)
;
}
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockShape
::
BlockSize
;
}
CK_TILE_HOST
static
constexpr
auto
BlockSize
()
{
return
Problem
::
BlockShape
::
BlockSize
;
}
...
@@ -170,7 +171,7 @@ struct AddRmsnorm2dRdquantFwd
...
@@ -170,7 +171,7 @@ struct AddRmsnorm2dRdquantFwd
number
<
1
>
{});
number
<
1
>
{});
const
auto
tmp2_
=
const
auto
tmp2_
=
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
kPad
M
>
{});
pad_tensor_view
(
tmp_
,
make_tuple
(
number
<
Block_N
>
{}),
sequence
<
kPad
N
>
{});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_N
>
{}),
{
0
});
return
make_tile_window
(
tmp2_
,
make_tuple
(
number
<
Block_N
>
{}),
{
0
});
}();
}();
...
...
include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp
deleted
100644 → 0
View file @
da1a2829
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
/*
// clang-format off
4-level descriptor: BlockTile-> WarpPerBlock-> WarpTile-> Vector
Block_N (Warp_N * WarpPerBlock_N * Repeat_N )
+<----------------------< Repeat_N(2)>--------------------->+
| |
+<-- <WarpPerBlock_N(2)> -->+
Warp_N
+--------------+--------------+--------------+--------------+----+----------------+
Warp_M | wrap_0 | wrap_1 | | ^ ^
+--------------+--------------+ | <WarpPerBlock_M(2)> |
| wrap_2 | wrap_3 | | v
+--------------+--------------+--------------+--------------+----+ Block_M
| | |
+ + |
| | | v
+--------------+--------------+--------------+--------------+ +
each Warp-tile (e.g 16 thrd per row)
Vector_N (contiguous pixels each thrd holds along N, or vector size)
+-----------+-----------+-----------+-----------+-----------+
| thrd_0 | thrd_1 | thrd_2 | thrd_3 | ... Vector_M
+-----------+-----------+-----------+-----------+-----------+
| thrd_16 | thrd_17 | thrd_18 | thrd_19 | ...
+-----------+-----------+-----------+-----------+-----------+
// clang-format on
*/
template
<
typename
BlockTile_
,
// block size, seq<M, N>
typename
WarpPerBlock_
,
// num warps along seq<M, N>
typename
WarpTile_
,
// warp size, seq<M, N>
typename
Vector_
,
// contiguous pixels(vector size) along seq<M, N>
index_t
BlockSize_
=
warpSize
*
reduce_on_sequence
(
WarpPerBlock_
{}
,
multiplies
{}
,
number
<
1
>{})
>
struct
AddRmsnorm2dRdquantShape
{
// block size
static
constexpr
index_t
Block_M
=
BlockTile_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_N
=
BlockTile_
::
at
(
number
<
1
>
{});
// num warps along seq<M, N>, within each block
static
constexpr
index_t
WarpPerBlock_M
=
WarpPerBlock_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
WarpPerBlock_N
=
WarpPerBlock_
::
at
(
number
<
1
>
{});
// warp size
static
constexpr
index_t
Warp_M
=
WarpTile_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Warp_N
=
WarpTile_
::
at
(
number
<
1
>
{});
static_assert
(
Block_M
%
(
WarpPerBlock_M
*
Warp_M
)
==
0
);
static_assert
(
Block_N
%
(
WarpPerBlock_N
*
Warp_N
)
==
0
);
// repeat of each thread along seq<M, N>
static
constexpr
index_t
Repeat_M
=
Block_M
/
(
WarpPerBlock_M
*
Warp_M
);
static
constexpr
index_t
Repeat_N
=
Block_N
/
(
WarpPerBlock_N
*
Warp_N
);
// vector size along seq<M, N>
static
constexpr
index_t
Vector_M
=
Vector_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Vector_N
=
Vector_
::
at
(
number
<
1
>
{});
static_assert
(
Warp_M
%
Vector_M
==
0
);
static_assert
(
Warp_N
%
Vector_N
==
0
);
// num of threads along seq<M, N>, within each warp
static
constexpr
index_t
ThreadPerWarp_M
=
Warp_M
/
Vector_M
;
static
constexpr
index_t
ThreadPerWarp_N
=
Warp_N
/
Vector_N
;
static
constexpr
index_t
BlockSize
=
BlockSize_
;
};
}
// namespace ck_tile
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp
View file @
ead5167a
...
@@ -26,6 +26,7 @@ struct AddRmsnorm2dRdquantFwdPipelineDefaultPolicy
...
@@ -26,6 +26,7 @@ struct AddRmsnorm2dRdquantFwdPipelineDefaultPolicy
sequence
<
1
,
1
,
2
,
2
>
,
sequence
<
1
,
1
,
2
,
2
>
,
sequence
<
0
,
3
,
0
,
3
>>
{});
sequence
<
0
,
3
,
0
,
3
>>
{});
}
}
template
<
typename
Problem
>
template
<
typename
Problem
>
CK_TILE_DEVICE
static
constexpr
auto
MakeGammaBlockTileDistribution
()
CK_TILE_DEVICE
static
constexpr
auto
MakeGammaBlockTileDistribution
()
{
{
...
...
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp
View file @
ead5167a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
include/ck_tile/ops/common.hpp
View file @
ead5167a
...
@@ -3,4 +3,5 @@
...
@@ -3,4 +3,5 @@
#pragma once
#pragma once
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/
rmsnorm2d/kernel/rmsnorm2d_fwd
_shape.hpp
→
include/ck_tile/ops/
common/generic_2d_block
_shape.hpp
View file @
ead5167a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 20
18-2023
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 20
24
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
#include "ck_tile/core.hpp"
namespace
ck_tile
{
namespace
ck_tile
{
/*
/*
// clang-format off
// clang-format off
...
@@ -42,7 +41,7 @@ template <typename BlockTile_, // block size, seq<M, N>
...
@@ -42,7 +41,7 @@ template <typename BlockTile_, // block size, seq<M, N>
typename
Vector_
,
// contiguous pixels(vector size) along seq<M, N>
typename
Vector_
,
// contiguous pixels(vector size) along seq<M, N>
index_t
BlockSize_
=
index_t
BlockSize_
=
warpSize
*
reduce_on_sequence
(
WarpPerBlock_
{}
,
multiplies
{}
,
number
<
1
>{})
>
warpSize
*
reduce_on_sequence
(
WarpPerBlock_
{}
,
multiplies
{}
,
number
<
1
>{})
>
struct
Rmsnorm2d
Shape
struct
Generic2dBlock
Shape
{
{
// block size
// block size
static
constexpr
index_t
Block_M
=
BlockTile_
::
at
(
number
<
0
>
{});
static
constexpr
index_t
Block_M
=
BlockTile_
::
at
(
number
<
0
>
{});
...
...
include/ck_tile/ops/elementwise.hpp
View file @
ead5167a
...
@@ -4,4 +4,5 @@
...
@@ -4,4 +4,5 @@
#pragma once
#pragma once
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/elementwise/unary_element_wise_operation.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/epilogue.hpp
View file @
ead5167a
...
@@ -5,4 +5,6 @@
...
@@ -5,4 +5,6 @@
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/cshuffle_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/default_2d_epilogue.hpp"
#include "ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/epilogue/default_2d_epilogue.hpp
View file @
ead5167a
...
@@ -9,13 +9,18 @@ namespace ck_tile {
...
@@ -9,13 +9,18 @@ namespace ck_tile {
// this epilogue just store out a M*N matrix, row major
// this epilogue just store out a M*N matrix, row major
template
<
typename
AccDataType_
,
typename
ODataType_
,
bool
kPadM_
,
bool
kPadN_
>
template
<
typename
AccDataType_
,
typename
ODataType_
,
bool
kPadM_
,
bool
kPadN_
,
bool
UseRawStore_
=
true
>
struct
Default2DEpilogueProblem
struct
Default2DEpilogueProblem
{
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
UseRawStore
=
UseRawStore_
;
};
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
template
<
typename
Problem_
,
typename
Policy_
=
void
>
...
@@ -26,6 +31,7 @@ struct Default2DEpilogue
...
@@ -26,6 +31,7 @@ struct Default2DEpilogue
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadM
=
Problem
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
kPadN
=
Problem
::
kPadN
;
static
constexpr
bool
UseRawStore
=
Problem
::
UseRawStore
;
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
return
0
;
}
...
@@ -36,7 +42,7 @@ struct Default2DEpilogue
...
@@ -36,7 +42,7 @@ struct Default2DEpilogue
{
{
// TODO: this is ugly
// TODO: this is ugly
if
constexpr
(
kPadM
||
kPadN
)
if
constexpr
(
UseRawStore
&&
(
kPadM
||
kPadN
)
)
{
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_tile
));
buffer_store_fence
();
buffer_store_fence
();
...
...
include/ck_tile/ops/epilogue/dynamic_quant_epilogue.hpp
0 → 100644
View file @
ead5167a
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce.hpp"
namespace
ck_tile
{
template
<
bool
kPadM_
,
bool
kPadN_
,
bool
UseRawStore_
=
true
,
bool
UseMax3_
=
false
>
struct
DynamicQuantEpilogueTraits
{
static
constexpr
bool
kPadM
=
kPadM_
;
static
constexpr
bool
kPadN
=
kPadN_
;
static
constexpr
bool
UseRawStore
=
UseRawStore_
;
static
constexpr
bool
UseMax3
=
UseMax3_
;
};
// this epilogue just store out a M*N matrix, row major
template
<
typename
AccDataType_
,
typename
YScaleDataType_
,
typename
ODataType_
,
typename
BlockShape_
,
typename
Traits_
>
struct
DynamicQuantEpilogueProblem
{
using
AccDataType
=
remove_cvref_t
<
AccDataType_
>
;
using
YScaleDataType
=
remove_cvref_t
<
YScaleDataType_
>
;
using
ODataType
=
remove_cvref_t
<
ODataType_
>
;
using
BlockShape
=
remove_cvref_t
<
BlockShape_
>
;
// can consum generic 2d shape
using
Traits
=
remove_cvref_t
<
Traits_
>
;
};
template
<
typename
Problem_
,
typename
Policy_
=
void
>
struct
DynamicQuantEpilogue
{
using
Problem
=
remove_cvref_t
<
Problem_
>
;
using
AccDataType
=
remove_cvref_t
<
typename
Problem
::
AccDataType
>
;
using
YScaleDataType
=
remove_cvref_t
<
typename
Problem
::
YScaleDataType
>
;
using
ODataType
=
remove_cvref_t
<
typename
Problem
::
ODataType
>
;
using
BlockShape
=
remove_cvref_t
<
typename
Problem
::
BlockShape
>
;
static
constexpr
bool
kPadM
=
Problem
::
Traits
::
kPadM
;
static
constexpr
bool
kPadN
=
Problem
::
Traits
::
kPadN
;
static
constexpr
bool
UseRawStore
=
Problem
::
Traits
::
UseRawStore
;
static
constexpr
bool
UseMax3
=
Problem
::
Traits
::
UseMax3
;
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2d
()
{
using
P_
=
BlockReduce2dProblem
<
AccDataType
,
AccDataType
,
BlockShape
>
;
return
BlockReduce2d
<
P_
>
{};
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dSync
()
{
using
P_
=
BlockReduce2dProblem
<
AccDataType
,
AccDataType
,
BlockShape
>
;
return
BlockReduce2dSync
<
P_
>
{};
}
CK_TILE_HOST_DEVICE
static
constexpr
auto
GetBlockReduce2dCrossWarpSync
()
{
using
P_
=
BlockReduce2dProblem
<
AccDataType
,
AccDataType
,
BlockShape
>
;
return
BlockReduce2dCrossWarpSync
<
P_
>
{};
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
GetSmemSize
()
{
auto
reduce_crosswarp_sync
=
GetBlockReduce2dCrossWarpSync
();
return
reduce_crosswarp_sync
.
GetSmemSize
();
}
// TODO: this function assume store out vector size is the same as OAccTile last dimension size
// how do we fix this ?
template
<
typename
ODramWindowTmp
,
typename
YScaleWindow
,
typename
OAccTile
>
CK_TILE_DEVICE
auto
operator
()(
ODramWindowTmp
&
o_dram_window_tmp
,
YScaleWindow
&
y_scale_window
,
const
OAccTile
&
o_acc_tile
,
void
*
smem
)
{
auto
reduce
=
GetBlockReduce2d
();
auto
reduce_sync
=
GetBlockReduce2dSync
();
auto
reduce_crosswarp_sync
=
GetBlockReduce2dCrossWarpSync
();
const
auto
f_absmax
=
[](
auto
acc_
,
auto
v_0_
)
{
return
max
(
acc_
,
abs
(
v_0_
));
};
auto
row_absmax
=
[
&
]()
{
constexpr
auto
y_size_per_row
=
OAccTile
{}.
get_tile_distribution
().
get_ys_to_d_descriptor
().
get_lengths
().
at
(
number
<
1
>
{});
// constexpr auto y_size_per_row = OAccTile::get_lengths()[number<1>{}];
if
constexpr
(
UseMax3
&&
std
::
is_same_v
<
AccDataType
,
float
>
&&
y_size_per_row
%
2
==
0
)
{
// fast max3 implementation
const
auto
f_max3
=
[](
auto
acc_
,
auto
v_0_
,
auto
v_1_
)
{
float
rtn
;
asm
volatile
(
"v_max3_f32 %0, %1, abs(%2), abs(%3)"
:
"=v"
(
rtn
)
:
"v"
(
acc_
),
"v"
(
v_0_
),
"v"
(
v_1_
));
return
rtn
;
};
return
reduce
(
o_acc_tile
,
type_convert
<
AccDataType
>
(
0
),
f_max3
,
sequence
<
1
,
2
>
{});
}
else
{
return
reduce
(
o_acc_tile
,
type_convert
<
AccDataType
>
(
0
),
f_absmax
);
}
}();
reduce_sync
(
row_absmax
,
f_absmax
);
reduce_crosswarp_sync
(
row_absmax
,
smem
,
f_absmax
);
// here y_scale is Acc TYpe, need convert to YScale type later
auto
y_scale
=
tile_elementwise_in
(
[
&
](
const
auto
&
v_
)
{
return
v_
/
type_convert
<
AccDataType
>
(
numeric
<
ODataType
>::
max
());
},
row_absmax
);
store_tile
(
y_scale_window
,
cast_tile
<
YScaleDataType
>
(
y_scale
));
auto
o_acc_scaled_tile
=
make_static_distributed_tensor
<
AccDataType
>
(
o_acc_tile
.
get_tile_distribution
());
sweep_tile
(
o_acc_tile
,
[
&
](
auto
idx
)
{
constexpr
auto
row_id
=
make_tuple
(
idx
[
number
<
0
>
{}]);
o_acc_scaled_tile
(
idx
)
=
o_acc_tile
[
idx
]
/
y_scale
(
row_id
);
});
// TODO: this is ugly
if
constexpr
(
UseRawStore
&&
(
kPadM
||
kPadN
))
{
store_tile_raw
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_scaled_tile
));
buffer_store_fence
();
}
else
{
store_tile
(
o_dram_window_tmp
,
cast_tile
<
ODataType
>
(
o_acc_scaled_tile
));
}
}
};
}
// namespace ck_tile
include/ck_tile/ops/fmha.hpp
View file @
ead5167a
...
@@ -43,4 +43,5 @@
...
@@ -43,4 +43,5 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/fmha/pipeline/tile_fmha_traits.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/fmha/pipeline/tile_fmha_shape.hpp
View file @
ead5167a
// SPDX-License-Identifier: MIT
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-202
3
, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-202
4
, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#pragma once
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment