gaoqiong / composable_kernel_ROCM / Commits / 7d50244e

Unverified commit 7d50244e, authored Oct 31, 2024 by Illia Silin, committed via GitHub on Oct 31, 2024.

Merge pull request #209 from ROCm/andriy/merge_from_public

Update develop branch from public repository

Parents: f221c2b0 d51701d4
Changes: 291

Showing 20 changed files with 1476 additions and 73 deletions (+1476 / -73)
include/ck_tile/host/fill.hpp  (+68 / -0)
include/ck_tile/host/host_tensor.hpp  (+23 / -0)
include/ck_tile/host/reference/reference_elementwise.hpp  (+47 / -0)
include/ck_tile/host/reference/reference_gemm.hpp  (+21 / -39)
include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp  (+32 / -5)
include/ck_tile/host/reference/reference_permute.hpp  (+57 / -0)
include/ck_tile/host/reference/reference_reduce.hpp  (+9 / -8)
include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp  (+52 / -0)
include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp  (+33 / -0)
include/ck_tile/host/reference/reference_softmax.hpp  (+59 / -21)
include/ck_tile/host/reference/reference_topk.hpp  (+124 / -0)
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp  (+13 / -0)
include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp  (+239 / -0)
include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp  (+78 / -0)
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp  (+94 / -0)
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp  (+142 / -0)
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp  (+41 / -0)
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp  (+266 / -0)
include/ck_tile/ops/common.hpp  (+1 / -0)
include/ck_tile/ops/common/generic_2d_block_shape.hpp  (+77 / -0)
include/ck_tile/host/fill.hpp

@@ -10,6 +10,7 @@
 #include <random>
 #include <type_traits>
 #include <utility>
+#include <unordered_set>

 #include "ck_tile/core.hpp"

@@ -41,6 +42,73 @@ struct FillUniformDistribution — this commit adds a FillUniformDistribution_Unique helper after the existing FillUniformDistribution:

        }
    };

    namespace impl {
    // clang-format off
    template<index_t bytes> struct RawIntegerType_ {};
    template<> struct RawIntegerType_<1> { using type = uint8_t; };
    template<> struct RawIntegerType_<2> { using type = uint16_t; };
    template<> struct RawIntegerType_<4> { using type = uint32_t; };
    template<> struct RawIntegerType_<8> { using type = uint64_t; };
    // clang-format on

    template <typename T>
    using RawIntegerType = typename RawIntegerType_<sizeof(T)>::type;
    } // namespace impl

    // Note: this struct will have no const-ness, will generate random values
    template <typename T>
    struct FillUniformDistribution_Unique
    {
        float a_{-5.f};
        float b_{5.f};
        std::optional<uint32_t> seed_{11939};
        std::mt19937 gen_{};
        std::unordered_set<impl::RawIntegerType<T>> set_{};

        FillUniformDistribution_Unique(float a                      = -5.f,
                                       float b                      = 5.f,
                                       std::optional<uint32_t> seed = {11939})
            : a_(a),
              b_(b),
              seed_(seed),
              gen_{seed_.has_value() ? *seed_ : std::random_device{}()},
              set_{}
        {
        }

        template <typename ForwardIter>
        void operator()(ForwardIter first, ForwardIter last)
        {
            std::mt19937& gen = gen_;
            std::uniform_real_distribution<float> dis(a_, b_);
            auto& set = set_;
            std::generate(first, last, [&dis, &gen, &set]() {
                T v = static_cast<T>(0);
                do
                {
                    v = ck_tile::type_convert<T>(dis(gen));
                } while(set.count(bit_cast<impl::RawIntegerType<T>>(v)) == 1);
                set.insert(bit_cast<impl::RawIntegerType<T>>(v));
                return v;
            });
        }

        template <typename ForwardRange>
        auto operator()(ForwardRange&& range) -> std::void_t<decltype(
            std::declval<FillUniformDistribution_Unique&>()(
                std::begin(std::forward<ForwardRange>(range)),
                std::end(std::forward<ForwardRange>(range))))>
        {
            (*this)(std::begin(std::forward<ForwardRange>(range)),
                    std::end(std::forward<ForwardRange>(range)));
        }

        void clear() { set_.clear(); }
    };

    template <typename T>
    struct FillNormalDistribution
    {
    ...
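A minimal usage sketch of the new functor (hypothetical buffer size and seed; it assumes the helper lives in the ck_tile namespace next to the other fill functors, and that any forward range with std::begin/std::end works, as the range overload above suggests):

    // Fill a host buffer with 1024 unique uniformly distributed values in [-5, 5).
    std::vector<float> data(1024);
    ck_tile::FillUniformDistribution_Unique<float>{-5.f, 5.f, 11939u}(data);
    // The functor remembers previously produced bit patterns; call clear() before reusing it
    // on another buffer if duplicates across buffers are acceptable.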
include/ck_tile/host/host_tensor.hpp

@@ -11,6 +11,7 @@
 #include <thread>
 #include <utility>
 #include <vector>
+#include <functional>

 #include "ck_tile/core.hpp"
 #include "ck_tile/host/ranges.hpp"

@@ -545,6 +546,28 @@ struct HostTensor — this commit adds a slice() member:

    typename Data::size_type size() const { return mData.size(); }

    // return a slice of this tensor
    // for simplicity we just copy the data and return a new tensor
    auto slice(std::vector<size_t> s_begin, std::vector<size_t> s_end) const
    {
        assert(s_begin.size() == s_end.size());
        assert(s_begin.size() == get_num_of_dimension());

        std::vector<size_t> s_len(s_begin.size());
        std::transform(
            s_end.begin(), s_end.end(), s_begin.begin(), s_len.begin(), std::minus<size_t>{});

        HostTensor<T> sliced_tensor(s_len);
        sliced_tensor.ForEach([&](auto& self, auto idx) {
            std::vector<size_t> src_idx(idx.size());
            std::transform(
                idx.begin(), idx.end(), s_begin.begin(), src_idx.begin(), std::plus<size_t>{});
            self(idx) = operator()(src_idx);
        });

        return sliced_tensor;
    }

    template <typename U = T>
    auto AsSpan() const
    {
    ...
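A short usage sketch (hypothetical tensor shape; it assumes HostTensor can be constructed from a list of lengths, as the slice() body itself does):

    // 2D tensor of shape {4, 8}; copy out rows [1, 3) and columns [2, 6).
    ck_tile::HostTensor<float> t({4, 8});
    // ... fill t ...
    auto sub = t.slice({1, 2}, {3, 6}); // sub has lengths {2, 4}; data is copied, not viewed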
include/ck_tile/host/reference/reference_elementwise.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <thread>

namespace ck_tile {

template <typename ADataType, typename BDataType, typename ComputeDataType, typename ElementOp>
CK_TILE_HOST void reference_unary_elementwise(const HostTensor<ADataType>& a,
                                              HostTensor<BDataType>& b,
                                              ElementOp element_op)
{
    // TODO: implement gpu version reference function
    auto f = [&](auto i) {
        auto v_a   = type_convert<ComputeDataType>(a.mData[i]);
        auto v_b   = element_op(v_a);
        b.mData[i] = ck_tile::type_convert<BDataType>(v_b);
    };

    make_ParallelTensorFunctor(f, b.get_element_space_size())(std::thread::hardware_concurrency());
}

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename ComputeDataType,
          typename ElementOp>
CK_TILE_HOST void reference_binary_elementwise(const HostTensor<ADataType>& a,
                                               const HostTensor<BDataType>& b,
                                               HostTensor<CDataType>& c,
                                               ElementOp element_op)
{
    // TODO: implement gpu version reference function
    auto f = [&](auto i) {
        auto v_a   = type_convert<ComputeDataType>(a.mData[i]);
        auto v_b   = type_convert<ComputeDataType>(b.mData[i]);
        auto v_c   = element_op(v_a, v_b);
        c.mData[i] = ck_tile::type_convert<CDataType>(v_c);
    };

    make_ParallelTensorFunctor(f, c.get_element_space_size())(std::thread::hardware_concurrency());
}
} // namespace ck_tile
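A minimal host-side usage sketch (hypothetical shapes; the lambda receives values already converted to ComputeDataType):

    // c = a + b elementwise, fp16 storage with fp32 compute.
    ck_tile::HostTensor<ck_tile::fp16_t> a({2, 64}), b({2, 64}), c({2, 64});
    ck_tile::reference_binary_elementwise<ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float>(
        a, b, c, [](float x, float y) { return x + y; });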
include/ck_tile/host/reference/reference_gemm.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once

 #include <cstdlib>
 #include <thread>

 #include "ck_tile/core.hpp"
 #include "ck_tile/host/host_tensor.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"

 namespace ck_tile {

@@ -14,55 +15,36 @@ template <typename ADataType,
          typename BDataType,
          typename AccDataType,
          typename CDataType,
          typename LayoutA,
          typename LayoutB,
          typename LayoutC,
          typename AElementOp   = ck_tile::identity,
          typename BElementOp   = ck_tile::identity,
          typename ACCElementOp = ck_tile::identity>

Removed (old layout-dispatching implementation):

CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                 const HostTensor<BDataType>& b_n_k,
                                 HostTensor<CDataType>& c_m_n,
                                 const AElementOp& a_element_op     = {},
                                 const BElementOp& b_element_op     = {},
                                 const ACCElementOp& acc_element_op = {})
{
    const int N = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                      ? b_n_k.mDesc.get_lengths()[0]
                      : b_n_k.mDesc.get_lengths()[1];
    const int K = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                      ? a_m_k.mDesc.get_lengths()[1]
                      : a_m_k.mDesc.get_lengths()[0];
    const int M = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                      ? a_m_k.mDesc.get_lengths()[0]
                      : a_m_k.mDesc.get_lengths()[1];

    auto f = [&](auto m) {
        for(int n = 0; n < N; ++n)
        {
            AccDataType v_acc = 0;
            for(int k = 0; k < K; ++k)
            {
                ADataType v_a = (std::is_same_v<LayoutA, tensor_layout::gemm::RowMajor>)
                                    ? a_element_op(a_m_k(m, k))
                                    : a_element_op(a_m_k(k, m));
                BDataType v_b = (std::is_same_v<LayoutB, tensor_layout::gemm::ColumnMajor>)
                                    ? b_element_op(b_n_k(n, k))
                                    : b_element_op(b_n_k(k, n));
                v_acc += ck_tile::type_convert<AccDataType>(v_a) *
                         ck_tile::type_convert<AccDataType>(v_b);
            }
            CDataType& c_ref = (std::is_same_v<LayoutC, tensor_layout::gemm::RowMajor>)
                                   ? c_m_n(m, n)
                                   : c_m_n(n, m);
            c_ref = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
        }
    };

    make_ParallelTensorFunctor(f, M)(std::thread::hardware_concurrency());
}

Added (new implementation taking A as [M, K], B as [K, N] and C as [M, N]):

CK_TILE_HOST void reference_gemm(const HostTensor<ADataType>& a_m_k,
                                 const HostTensor<BDataType>& b_k_n,
                                 HostTensor<CDataType>& c_m_n,
                                 const AElementOp& a_element_op     = {},
                                 const BElementOp& b_element_op     = {},
                                 const ACCElementOp& acc_element_op = {})
{
    const std::size_t M = a_m_k.get_length(0);
    const std::size_t N = b_k_n.get_length(1);
    const std::size_t K = a_m_k.get_length(1);

    auto f_mn = [&](auto m, auto n) {
        AccDataType v_acc = 0;
        for(std::size_t k = 0; k < K; ++k)
        {
            ADataType v_a = a_element_op(a_m_k(m, k));
            BDataType v_b = b_element_op(b_k_n(k, n));
            v_acc += ck_tile::type_convert<AccDataType>(v_a) *
                     ck_tile::type_convert<AccDataType>(v_b);
        }
        c_m_n(m, n) = ck_tile::type_convert<CDataType>(acc_element_op(v_acc));
    };

    make_ParallelTensorFunctor(f_mn, M, N)(std::thread::hardware_concurrency());
}

template <typename ADataType,
...
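A usage sketch for the updated reference (hypothetical sizes and data types; the exact set of template parameters retained after the layout parameters were dropped is not visible in this hunk, so the explicit arguments below are an assumption):

    // C[M,N] = A[M,K] * B[K,N], fp16 inputs with fp32 accumulation.
    std::size_t M = 128, N = 128, K = 64;
    ck_tile::HostTensor<ck_tile::fp16_t> a_m_k({M, K});
    ck_tile::HostTensor<ck_tile::fp16_t> b_k_n({K, N});
    ck_tile::HostTensor<ck_tile::fp16_t> c_m_n({M, N});
    ck_tile::reference_gemm<ck_tile::fp16_t, ck_tile::fp16_t, float, ck_tile::fp16_t>(
        a_m_k, b_k_n, c_m_n);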
include/ck_tile/host/reference/reference_layernorm2d_fwd.hpp

@@ -8,20 +8,44 @@
 namespace ck_tile {

+// Note: for simplicity, each functor only cares about a single M
+struct reference_layernorm2d_default_epilogue
+{
+    template <typename OutDataType, typename AccDataType>
+    void operator()(int m, HostTensor<OutDataType>& o, const HostTensor<AccDataType>& acc)
+    {
+        const int N = acc.mDesc.get_lengths()[1];
+        for(int n = 0; n < N; ++n)
+        {
+            o(m, n) = ck_tile::type_convert<OutDataType>(acc(m, n));
+        }
+    }
+
+    template <typename OutDataType, typename AccDataType>
+    auto operator()(int m, const HostTensor<AccDataType>& acc)
+    {
+        HostTensor<OutDataType> o(acc.get_lengths(), acc.get_strides());
+        operator()(m, o, acc);
+        return o;
+    }
+};
+
 template <typename XDataType,
           typename GammaDataType,
           typename BetaDataType,
           typename ComputeDataType,
           typename YDataType,
           typename MeanDataType,
-          typename InvStdDataType>
+          typename InvStdDataType,
+          typename Epilogue = reference_layernorm2d_default_epilogue>
 void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
                                const HostTensor<GammaDataType>& gamma_n,
                                const HostTensor<BetaDataType>& beta_n,
                                HostTensor<YDataType>& y_m_n,
                                HostTensor<MeanDataType>& mean_m,
                                HostTensor<InvStdDataType>& invStd_m,
-                               ComputeDataType epsilon)
+                               ComputeDataType epsilon,
+                               Epilogue epilogue_functor = {})
 {
     auto layernorm2d_fwd_func = [&](auto m) {
         const int N = x_m_n.mDesc.get_lengths()[1];
 ...

@@ -51,16 +75,19 @@ void reference_layernorm2d_fwd(const HostTensor<XDataType>& x_m_n,
         if constexpr(!std::is_same_v<InvStdDataType, ck_tile::null_type>)
             invStd_m(m) = ck_tile::type_convert<InvStdDataType>(divisor);

+        HostTensor<ComputeDataType> acc(x_m_n.get_lengths(), x_m_n.get_strides());
+
         for(int n = 0; n < N; ++n)
         {
             ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
             ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
             ComputeDataType beta  = ck_tile::type_convert<ComputeDataType>(beta_n(n));
-            auto y = (x - mean) * divisor;
-            y      = y * gamma + beta;
+            auto a_ = (x - mean) * divisor;
+            a_      = a_ * gamma + beta;

-            y_m_n(m, n) = ck_tile::type_convert<YDataType>(y);
+            acc(m, n) = a_;
         }
+        epilogue_functor(m, y_m_n, acc);
     };

     make_ParallelTensorFunctor(layernorm2d_fwd_func,
 ...
include/ck_tile/host/reference/reference_permute.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <thread>
#include <numeric>
#include <functional>

namespace ck_tile {

/*
 * this will do permute + contiguous like functionality in pytorch
 */
template <typename DataType>
CK_TILE_HOST void reference_permute(const HostTensor<DataType>& x,
                                    HostTensor<DataType>& y,
                                    std::vector<index_t> dims)
{
    const auto x_len = x.mDesc.get_lengths();
    const auto y_len = y.mDesc.get_lengths();
    assert(x_len.size() == y_len.size());
    index_t rank = x_len.size();

    const auto x_elm = std::accumulate(x_len.begin(), x_len.end(), 1, std::multiplies<index_t>());
    const auto y_elm = std::accumulate(y_len.begin(), y_len.end(), 1, std::multiplies<index_t>());
    assert(x_elm == y_elm);
    (void)y_elm;

    auto f = [&](auto i_element) {
        std::vector<size_t> y_coord = [&]() {
            std::vector<size_t> tmp(rank, 0);
            size_t r = i_element;
            for(index_t i = rank - 1; i >= 0; i--)
            {
                tmp[i] = r % y_len[i];
                r      = r / y_len[i];
            }
            return tmp;
        }();

        std::vector<size_t> x_coord = [&]() {
            std::vector<size_t> tmp(rank, 0);
            for(index_t i = 0; i < rank; i++)
            {
                tmp[dims[i]] = y_coord[i];
            }
            return tmp;
        }();

        // do permute
        y(y_coord) = x(x_coord);
    };

    make_ParallelTensorFunctor(f, x_elm)(std::thread::hardware_concurrency());
}
} // namespace ck_tile
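A usage sketch (hypothetical shapes; dims follows the torch.permute convention, i.e. output dim i takes its length from input dim dims[i], so y must already be allocated with the permuted lengths):

    // y = x.permute(2, 0, 1).contiguous() for a {2, 3, 4} tensor.
    ck_tile::HostTensor<float> x({2, 3, 4});
    ck_tile::HostTensor<float> y({4, 2, 3}); // y_len[i] == x_len[dims[i]]
    ck_tile::reference_permute(x, y, {2, 0, 1});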
include/ck_tile/host/reference/reference_reduce.hpp

@@ -9,24 +9,25 @@
 namespace ck_tile {

Removed (old fixed sum reduction):

template <typename ADataType, typename AccDataType, typename BDataType>
CK_TILE_HOST void reference_reduce(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m)
{
    auto f = [&](auto m) {
        const int N       = a_m_n.mDesc.get_lengths()[1];
        AccDataType v_acc = 0;
        for(int n = 0; n < N; ++n)
        {
            const ADataType v_a = a_m_n(m, n);
            v_acc += v_a;
        }
        b_m(m) = ck_tile::type_convert<BDataType>(v_acc);
    };

    make_ParallelTensorFunctor(f, b_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}

Added (new version parameterized on a reduce operator):

template <typename XDataType, typename ComputeDataType, typename YDataType, typename ReduceOp>
CK_TILE_HOST void reference_reduce(const HostTensor<XDataType>& x_m_n,
                                   HostTensor<YDataType>& y_m,
                                   ReduceOp reduce_op)
{
    auto f = [&](auto m) {
        const int N           = x_m_n.mDesc.get_lengths()[1];
        ComputeDataType v_acc = reduce_op.template GetIdentityValue<ComputeDataType>();
        for(int n = 0; n < N; ++n)
        {
            const ComputeDataType v_a = type_convert<ComputeDataType>(x_m_n(m, n));
            v_acc                     = reduce_op(v_acc, v_a);
        }
        y_m(m) = ck_tile::type_convert<YDataType>(v_acc);
    };

    make_ParallelTensorFunctor(f, y_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
} // namespace ck_tile
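A usage sketch (hypothetical shapes; it assumes the ReduceOp functors referenced elsewhere in this commit, such as ReduceOp::Add, expose GetIdentityValue<T>() and operator()(acc, v) as this reference expects):

    // Row-wise sum of a {8, 256} fp16 tensor into an fp32 vector of length 8.
    ck_tile::HostTensor<ck_tile::fp16_t> x({8, 256});
    ck_tile::HostTensor<float> y({8});
    ck_tile::reference_reduce<ck_tile::fp16_t, float, float>(x, y, ck_tile::ReduceOp::Add{});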
include/ck_tile/host/reference/reference_rmsnorm2d_fwd.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

namespace ck_tile {

template <typename XDataType,
          typename GammaDataType,
          typename ComputeDataType,
          typename YDataType,
          typename InvRmsDataType>
void reference_rmsnorm2d_fwd(const HostTensor<XDataType>& x_m_n,
                             const HostTensor<GammaDataType>& gamma_n,
                             HostTensor<YDataType>& y_m_n,
                             HostTensor<InvRmsDataType>& invRms_m,
                             ComputeDataType epsilon)
{
    auto rmsnorm2d_fwd_func = [&](auto m) {
        const int N = x_m_n.mDesc.get_lengths()[1];

        ComputeDataType mean_square = 0;
        ComputeDataType divisor     = 0;

        for(int n = 0; n < N; ++n)
        {
            ComputeDataType x = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
            mean_square += x * x;
        }

        mean_square = mean_square / N;
        divisor = ck_tile::type_convert<ComputeDataType>(1) / ck_tile::sqrt(mean_square + epsilon);

        if constexpr(!std::is_same_v<InvRmsDataType, ck_tile::null_type>)
            invRms_m(m) = ck_tile::type_convert<InvRmsDataType>(divisor);

        for(int n = 0; n < N; ++n)
        {
            ComputeDataType x     = ck_tile::type_convert<ComputeDataType>(x_m_n(m, n));
            ComputeDataType gamma = ck_tile::type_convert<ComputeDataType>(gamma_n(n));
            auto y                = x * divisor * gamma;
            y_m_n(m, n)           = ck_tile::type_convert<YDataType>(y);
        }
    };

    make_ParallelTensorFunctor(rmsnorm2d_fwd_func,
                               invRms_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
} // namespace ck_tile
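A usage sketch (hypothetical shapes; fp16 input/output with fp32 compute, keeping the per-row inverse RMS):

    ck_tile::HostTensor<ck_tile::fp16_t> x({16, 1024}), y({16, 1024});
    ck_tile::HostTensor<ck_tile::fp16_t> gamma({1024});
    ck_tile::HostTensor<float> inv_rms({16});
    ck_tile::reference_rmsnorm2d_fwd<ck_tile::fp16_t, ck_tile::fp16_t, float, ck_tile::fp16_t, float>(
        x, gamma, y, inv_rms, 1e-6f);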
include/ck_tile/host/reference/reference_rowwise_quantization2d.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <thread>

namespace ck_tile {

template <typename XDataType, typename ScaleDataType, typename QXDataType>
CK_TILE_HOST void reference_rowwise_quantization2d(const HostTensor<XDataType>& x_m_n,
                                                   const HostTensor<ScaleDataType>& scale_m,
                                                   HostTensor<QXDataType>& qx_m_n)
{
    auto f = [&](auto m) {
        const int N = x_m_n.mDesc.get_lengths()[1];

        for(int n = 0; n < N; ++n)
        {
            auto v_x = x_m_n(m, n);
            // scale = amax / 127 for int8
            auto v_scale = type_convert<XDataType>(scale_m(m));
            auto v_qx    = v_x / v_scale;
            qx_m_n(m, n) = saturates<QXDataType>{}(v_qx);
        }
    };

    make_ParallelTensorFunctor(f, scale_m.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}
} // namespace ck_tile
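A usage sketch (hypothetical shapes; the per-row scale is assumed to be precomputed by the caller, e.g. amax / 127 for int8, and int8_t is assumed to be an accepted QXDataType for saturates<>):

    ck_tile::HostTensor<float> x({4, 128});
    ck_tile::HostTensor<float> scale({4});    // per-row scale, computed beforehand
    ck_tile::HostTensor<int8_t> qx({4, 128});
    ck_tile::reference_rowwise_quantization2d(x, scale, qx);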
include/ck_tile/host/reference/reference_softmax.hpp

 // SPDX-License-Identifier: MIT
-// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
+// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

 #pragma once
 ...

@@ -9,43 +9,81 @@
 namespace ck_tile {

Removed (old 2D, last-dim-only implementation):

template <typename ADataType, typename AccDataType, typename BDataType>
CK_TILE_HOST void reference_softmax(const HostTensor<ADataType>& a_m_n, HostTensor<BDataType>& b_m_n)
{
    auto f = [&](auto m) {
        const int N = a_m_n.mDesc.get_lengths()[1];

        AccDataType v_max = ck_tile::numeric<ADataType>::Lowest();
        // max
        for(int n = 0; n < N; ++n)
        {
            const ADataType v_a = a_m_n(m, n);
            v_max               = v_max < v_a ? v_a : v_max;
        }

        AccDataType v_exp_sum = 0;
        // sum
        for(int n = 0; n < N; ++n)
        {
            const ADataType v_a = a_m_n(m, n);
            v_exp_sum += ck_tile::exp(v_a - v_max);
        }

        // elementwise
        for(int n = 0; n < N; ++n)
        {
            const ADataType v_a = a_m_n(m, n);
            b_m_n(m, n)         = ck_tile::exp(v_a - v_max) / v_exp_sum;
        }
    };

    make_ParallelTensorFunctor(f, b_m_n.mDesc.get_lengths()[0])(std::thread::hardware_concurrency());
}

Added (new rank-agnostic implementation with a selectable dim, plus a convenience overload):

template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST void reference_softmax(const HostTensor<InputType>& x,
                                    HostTensor<OutputType>& y,
                                    index_t dim = -1)
{
    index_t rank = x.get_num_of_dimension();
    assert(rank == y.get_num_of_dimension());
    assert(dim == -1 || dim < rank);

    index_t target_dim  = dim == -1 ? (rank - 1) : dim;
    index_t softmax_len = x.get_length(target_dim);
    index_t n_parallel  = x.get_element_size() / softmax_len;
    auto x_len          = x.get_lengths();

    auto f = [&](auto i_element) {
        std::vector<size_t> coord = [&]() {
            std::vector<size_t> t_(rank, 0);
            size_t r = i_element;
            for(index_t i = rank - 1; i >= 0; i--)
            {
                if(i == target_dim)
                    continue;
                t_[i] = r % x_len[i];
                r     = r / x_len[i];
            }
            return t_;
        }();

        ComputeType v_max = -ck_tile::numeric<ComputeType>::infinity();
        // compute max
        for(auto idx = 0; idx < softmax_len; idx++)
        {
            auto c_             = coord;
            c_[target_dim]      = idx;
            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
            v_max               = v_max < v_x ? v_x : v_max;
        }

        ComputeType v_exp_sum = static_cast<ComputeType>(0);
        // sum
        for(auto idx = 0; idx < softmax_len; idx++)
        {
            auto c_             = coord;
            c_[target_dim]      = idx;
            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
            v_exp_sum += ck_tile::exp(v_x - v_max);
        }

        // elementwise
        for(auto idx = 0; idx < softmax_len; idx++)
        {
            auto c_             = coord;
            c_[target_dim]      = idx;
            const ComputeType v_x = ck_tile::type_convert<ComputeType>(x(c_));
            auto out            = ck_tile::exp(v_x - v_max) / v_exp_sum;
            y(c_)               = ck_tile::type_convert<OutputType>(out);
        }
    };

    make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}

template <typename InputType, typename ComputeType, typename OutputType = ComputeType>
CK_TILE_HOST auto reference_softmax(const HostTensor<InputType>& x, index_t dim = -1)
{
    HostTensor<OutputType> y(x.get_lengths(), x.get_strides());
    reference_softmax<InputType, ComputeType, OutputType>(x, y, dim);
    return y;
}
} // namespace ck_tile
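A usage sketch (hypothetical shapes; OutputType defaults to the compute type in the value-returning overload):

    ck_tile::HostTensor<ck_tile::fp16_t> x({2, 4, 32});
    // softmax over the last dimension, returning a new fp32 tensor
    auto y = ck_tile::reference_softmax<ck_tile::fp16_t, float>(x);
    // or write into an existing tensor and pick the dimension explicitly
    ck_tile::HostTensor<float> y2({2, 4, 32});
    ck_tile::reference_softmax<ck_tile::fp16_t, float>(x, y2, 1);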
include/ck_tile/host/reference/reference_topk.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/host/host_tensor.hpp"

#include <thread>
#include <numeric>
#include <functional>
#include <utility>
#include <algorithm>

namespace ck_tile {

/*
 * similar to torch.topk()
 *   x (Tensor)               - the input tensor.
 *   k (int)                  - the k in "top-k"
 *   dim (int, optional)      - the dimension to sort along
 *   largest (bool, optional) - largest or smallest elements
 *   sorted (bool, optional)  - elements in sorted order or not
 *
 * output:
 *   y_values
 *   y_indices
 *
 * https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/TopKImpl.h
 */
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST void reference_topk(const HostTensor<DataType>& x,
                                 HostTensor<DataType>& y_values,
                                 HostTensor<IndexType>& y_indices,
                                 index_t k,
                                 index_t dim  = -1,
                                 bool largest = true,
                                 bool sorted  = true)
{
    // rank must be the same
    index_t rank = x.get_num_of_dimension();
    assert(rank == y_values.get_num_of_dimension());
    assert(rank == y_indices.get_num_of_dimension());
    assert(dim == -1 || dim < rank);

    index_t topk_dim     = dim == -1 ? (rank - 1) : dim;
    index_t topk_src_len = x.get_length(topk_dim);
    auto x_len           = x.get_lengths();

    assert(k <= topk_src_len);
    assert(k == y_values.get_length(topk_dim) && k == y_indices.get_length(topk_dim));

    index_t n_parallel = x.get_element_size() / topk_src_len;

    // clang-format off
    auto f = [&](auto i_element) {
        std::vector<size_t> topk_coord = [&](){
            std::vector<size_t> t_(rank, 0);
            size_t r = i_element;
            for(index_t i = rank - 1; i >= 0; i--)
            {
                if(i == topk_dim) continue; // topk dim should be zero
                t_[i] = r % x_len[i];
                r     = r / x_len[i];
            }
            return t_;
        }();

        using elem_t = std::pair<DataType, IndexType>;
        std::vector<elem_t> q = [&](){
            std::vector<elem_t> t_(topk_src_len);
            for(index_t i = 0; i < topk_src_len; i++)
            {
                auto c_      = topk_coord;
                c_[topk_dim] = i;
                t_[i].first  = x(c_);
                t_[i].second = i;
            }
            return t_;
        }();

        // run topk
        if(largest)
        {
            std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
                [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
            if(sorted)
            {
                std::sort(q.begin(), q.begin() + k - 1,
                    [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first > rhs.first; });
            }
        }
        else
        {
            std::nth_element(q.begin(), q.begin() + k - 1, q.end(),
                [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
            if(sorted)
            {
                std::sort(q.begin(), q.begin() + k - 1,
                    [](const elem_t& lhs, const elem_t& rhs) -> bool { return lhs.first < rhs.first; });
            }
        }

        // write out
        for(index_t i = 0; i < k; i++)
        {
            auto c_       = topk_coord;
            c_[topk_dim]  = i;
            y_values(c_)  = q[i].first;
            y_indices(c_) = q[i].second;
        }
    };
    // clang-format on

    make_ParallelTensorFunctor(f, n_parallel)(std::thread::hardware_concurrency());
}

// TODO: if using this method, the return tensor would be dense(no stride)
template <typename DataType, typename IndexType = index_t>
CK_TILE_HOST auto reference_topk(const HostTensor<DataType>& x,
                                 index_t k,
                                 index_t dim  = -1,
                                 bool largest = true,
                                 bool sorted  = true)
{
    auto lens          = x.get_lengths();
    index_t target_dim = (dim == -1) ? (lens.size() - 1) : dim;
    assert(target_dim < lens.size());
    assert(k <= lens[target_dim]);

    lens[target_dim] = k;
    HostTensor<DataType> y_values(lens);
    HostTensor<IndexType> y_indices(lens);

    reference_topk<DataType, IndexType>(x, y_values, y_indices, k, dim, largest, sorted);
    return ck_tile::make_tuple(y_values, y_indices);
}
} // namespace ck_tile
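A usage sketch with preallocated outputs (hypothetical shapes; the top-k dimension of the output tensors must already have length k):

    // top-2 values and indices along the last dim of a {4, 8} tensor, mirroring torch.topk
    ck_tile::HostTensor<float> x({4, 8});
    ck_tile::HostTensor<float> vals({4, 2});
    ck_tile::HostTensor<ck_tile::index_t> idxs({4, 2});
    ck_tile::reference_topk<float>(x, vals, idxs, 2);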
include/ck_tile/ops/add_rmsnorm2d_rdquant.hpp  (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp"
#include "ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp"
#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_kernel.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"

namespace ck_tile {

// host side args
struct AddRmsnorm2dRdquantFwdHostArgs
{
    const void* p_a;
    const void* p_b;
    const void* p_gamma;
    void* p_x;
    void* p_yscale;
    void* p_qy;

    float epsilon;

    index_t m;
    index_t n;
    index_t stride; // row_stride
};

// TODO: Extract some type to wrapper class
template <typename Pipeline_>
struct AddRmsnorm2dRdquantFwd
{
    using Pipeline = remove_cvref_t<Pipeline_>;
    using Problem  = typename Pipeline::Problem;

    using ADataType       = remove_cvref_t<typename Problem::ADataType>;
    using BDataType       = remove_cvref_t<typename Problem::BDataType>;
    using GammaDataType   = remove_cvref_t<typename Problem::GammaDataType>;
    using ComputeDataType = remove_cvref_t<typename Problem::ComputeDataType>;
    using XDataType       = remove_cvref_t<typename Problem::XDataType>;
    using YScaleDataType  = remove_cvref_t<typename Problem::YScaleDataType>;
    using QYDataType      = remove_cvref_t<typename Problem::QYDataType>;

    static constexpr bool kSaveX = Problem::kSaveX;

    static constexpr index_t Block_M = Problem::BlockShape::Block_M;
    static constexpr index_t Block_N = Problem::BlockShape::Block_N;
    static constexpr bool kPadM      = false; // always no need to pad along M
    static constexpr bool kPadN      = Problem::kPadN;
    static constexpr bool kThreePass = Problem::kThreePass;

    static constexpr index_t ThreadPerWarp_N = Problem::BlockShape::ThreadPerWarp_N;
    static constexpr index_t Vector_N        = Problem::BlockShape::Vector_N;
    static constexpr index_t Repeat_N        = Problem::BlockShape::Repeat_N;

    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};

    struct Kargs
    {
        const void* p_a;
        const void* p_b;
        const void* p_gamma;
        void* p_x;
        void* p_yscale;
        void* p_qy;

        float epsilon;

        index_t m;
        index_t n;
        index_t stride; // row_stride
    };
    using Hargs = AddRmsnorm2dRdquantFwdHostArgs;

    CK_TILE_HOST static constexpr Kargs MakeKargs(const Hargs& hargs)
    {
        return Kargs{hargs.p_a,
                     hargs.p_b,
                     hargs.p_gamma,
                     hargs.p_x,
                     hargs.p_yscale,
                     hargs.p_qy,
                     hargs.epsilon,
                     hargs.m,
                     hargs.n,
                     hargs.stride};
    }

    CK_TILE_HOST static constexpr auto GridSize(const Hargs& hargs)
    {
        return integer_divide_ceil(hargs.m, Block_M);
    }

    CK_TILE_HOST static constexpr auto BlockSize() { return Problem::BlockShape::BlockSize; }

    // clang-format off
    template <typename T> struct t2s;
    template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
    template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
    template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
    template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
    template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
    // clang-format on

    // in byte
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() { return Pipeline::GetSmemSize(); }

    CK_TILE_HOST static std::string GetName()
    {
        // clang-format off
        using S_ = typename Problem::BlockShape;
        auto surfix = [&] () {
            std::string n;
            if (kPadN) n += "_pn";
            if (kSaveX) n += "_x";
            if (kThreePass) n += "_2p";
            return n; }();

        #define _SS_  std::string
        #define _TS_  std::to_string
        return _SS_("add_rmsnorm2d_rdquant_fwd_") + _SS_(t2s<XDataType>::name) + "_" +
            _TS_(S_::Block_M) + "x" + _TS_(S_::Block_N) + "_" +
            _TS_(S_::WarpPerBlock_M) + "x" + _TS_(S_::WarpPerBlock_N) + "_" +
            _TS_(S_::Warp_M) + "x" + _TS_(S_::Warp_N) + "_" +
            _TS_(S_::Vector_M) + "x" + _TS_(S_::Vector_N) + "_" +
            _SS_(Pipeline::name) + surfix;
        #undef _SS_
        #undef _TS_
        // clang-format on
    }

    CK_TILE_DEVICE void operator()(Kargs kargs) const
    {
        const auto iM = get_block_id() * Block_M;

        const auto a_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const ADataType*>(kargs.p_a),
                make_tuple(kargs.m, kargs.n),
                make_tuple(kargs.stride, 1),
                number<Vector_N>{},
                number<1>{});

            const auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
            return make_tile_window(
                tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
        }();

        const auto b_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const BDataType*>(kargs.p_b),
                make_tuple(kargs.m, kargs.n),
                make_tuple(kargs.stride, 1),
                number<Vector_N>{},
                number<1>{});

            const auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
            return make_tile_window(
                tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
        }();

        const auto gamma_window = [&]() {
            const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<const GammaDataType*>(kargs.p_gamma),
                make_tuple(kargs.n),
                make_tuple(1),
                number<Vector_N>{},
                number<1>{});

            const auto tmp2_ =
                pad_tensor_view(tmp_, make_tuple(number<Block_N>{}), sequence<kPadM>{});

            return make_tile_window(tmp2_, make_tuple(number<Block_N>{}), {0});
        }();

        auto x_window = [&]() {
            if constexpr(kSaveX)
            {
                const auto tmp2_ = [&]() {
                    const auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                        static_cast<XDataType*>(kargs.p_x),
                        make_tuple(kargs.m, kargs.n),
                        make_tuple(kargs.stride, 1),
                        number<Vector_N>{},
                        number<1>{});

                    return pad_tensor_view(tmp_,
                                           make_tuple(number<Block_M>{}, number<Block_N>{}),
                                           sequence<kPadM, kPadN>{});
                }();
                return make_tile_window(
                    tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
            }
            else
                return make_null_tile_window(make_tuple(number<Block_M>{}, number<Block_N>{}));
        }();

        auto yscale_window = [&]() {
            auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<YScaleDataType*>(kargs.p_yscale),
                make_tuple(kargs.m),
                make_tuple(1),
                number<1>{});

            auto tmp2_ = pad_tensor_view(tmp_, make_tuple(number<Block_M>{}), sequence<kPadM>{});

            return make_tile_window(tmp2_, make_tuple(number<Block_M>{}), {iM});
        }();

        auto qy_window = [&]() {
            auto tmp_ = make_naive_tensor_view<address_space_enum::global>(
                static_cast<QYDataType*>(kargs.p_qy),
                make_tuple(kargs.m, kargs.n),
                make_tuple(kargs.stride, 1),
                number<Vector_N>{},
                number<1>{});

            auto tmp2_ = pad_tensor_view(
                tmp_, make_tuple(number<Block_M>{}, number<Block_N>{}), sequence<kPadM, kPadN>{});
            return make_tile_window(
                tmp2_, make_tuple(number<Block_M>{}, number<Block_N>{}), {iM, 0});
        }();

        __shared__ char smem[GetSmemSize()];

        Pipeline{}(a_window,
                   b_window,
                   gamma_window,
                   x_window,
                   yscale_window,
                   qy_window,
                   static_cast<const ComputeDataType>(kargs.epsilon),
                   kargs.n,
                   smem);
    }
};
} // namespace ck_tile
include/ck_tile/ops/add_rmsnorm2d_rdquant/kernel/add_rmsnorm2d_rdquant_fwd_shape.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

/*
// clang-format off
4-level descriptor: BlockTile -> WarpPerBlock -> WarpTile -> Vector

  Block_N (Warp_N * WarpPerBlock_N * Repeat_N)
  +<----------------------< Repeat_N(2) >--------------------->+
  |                                                            |
  +<-- <WarpPerBlock_N(2)> -->+
         Warp_N
  +--------------+--------------+--------------+--------------+----+----------------+
  |    warp_0    |    warp_1    |              |              |  ^                  ^   Warp_M
  +--------------+--------------+              |              |  <WarpPerBlock_M(2)>|
  |    warp_2    |    warp_3    |              |              |  v                  |
  +--------------+--------------+--------------+--------------+----+             Block_M
  |              |              |              |              |                     |
  +              +              |              |              |                     |
  |              |              |              |              |                     v
  +--------------+--------------+--------------+--------------+                     +

  each Warp-tile (e.g. 16 thrd per row)
       Vector_N (contiguous pixels each thrd holds along N, or vector size)
  +-----------+-----------+-----------+-----------+-----------+
  |  thrd_0   |  thrd_1   |  thrd_2   |  thrd_3   | ...          Vector_M
  +-----------+-----------+-----------+-----------+-----------+
  |  thrd_16  |  thrd_17  |  thrd_18  |  thrd_19  | ...
  +-----------+-----------+-----------+-----------+-----------+
// clang-format on
*/
template <typename BlockTile_,    // block size, seq<M, N>
          typename WarpPerBlock_, // num warps along seq<M, N>
          typename WarpTile_,     // warp size, seq<M, N>
          typename Vector_,       // contiguous pixels(vector size) along seq<M, N>
          index_t BlockSize_ =
              warpSize * reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
struct AddRmsnorm2dRdquantShape
{
    // block size
    static constexpr index_t Block_M = BlockTile_::at(number<0>{});
    static constexpr index_t Block_N = BlockTile_::at(number<1>{});

    // num warps along seq<M, N>, within each block
    static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
    static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});

    // warp size
    static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
    static constexpr index_t Warp_N = WarpTile_::at(number<1>{});

    static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
    static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);

    // repeat of each thread along seq<M, N>
    static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
    static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);

    // vector size along seq<M, N>
    static constexpr index_t Vector_M = Vector_::at(number<0>{});
    static constexpr index_t Vector_N = Vector_::at(number<1>{});

    static_assert(Warp_M % Vector_M == 0);
    static_assert(Warp_N % Vector_N == 0);

    // num of threads along seq<M, N>, within each warp
    static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
    static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;

    static constexpr index_t BlockSize = BlockSize_;
};
} // namespace ck_tile
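A sketch of how the new kernel pieces compose on the host side (all tile sizes and data types below are hypothetical illustrations, not tuned configurations; only the types and the MakeKargs/GridSize/BlockSize helpers defined in this commit are assumed):

    // Block tile 8x256, 2x2 warps of 4x64, vector 1x4 -> 4x16 threads per warp.
    using Shape = ck_tile::AddRmsnorm2dRdquantShape<ck_tile::sequence<8, 256>,
                                                    ck_tile::sequence<2, 2>,
                                                    ck_tile::sequence<4, 64>,
                                                    ck_tile::sequence<1, 4>>;
    using Problem = ck_tile::AddRmsnorm2dRdquantFwdPipelineProblem<
        ck_tile::fp16_t, ck_tile::fp16_t, ck_tile::fp16_t, float,   // A, B, Gamma, Compute
        ck_tile::fp16_t, float, int8_t,                             // X, YScale, QY
        Shape, /*kPadN*/ true, /*kSaveX*/ true, /*kThreePass*/ false>;
    using Pipeline = ck_tile::AddRmsnorm2dRdquantFwdPipelineOnePass<Problem>;
    using Kernel   = ck_tile::AddRmsnorm2dRdquantFwd<Pipeline>;

    ck_tile::AddRmsnorm2dRdquantFwdHostArgs hargs{p_a, p_b, p_gamma, p_x, p_yscale, p_qy,
                                                  1e-6f, M, N, /*stride*/ N};
    auto kargs  = Kernel::MakeKargs(hargs);
    auto grids  = Kernel::GridSize(hargs);  // one block per Block_M rows
    auto blocks = Kernel::BlockSize();
    // launch Kernel{} with (grids, blocks) using the project's usual kernel-launch helper.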
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_default_policy.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d_problem.hpp"
#include "ck_tile/ops/reduce/block/block_reduce2d.hpp"

namespace ck_tile {

struct AddRmsnorm2dRdquantFwdPipelineDefaultPolicy
{
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeABXBlockTileDistribution()
    {
        using S = typename Problem::BlockShape;
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<>,
                tuple<sequence<S::Repeat_M, S::WarpPerBlock_M, S::ThreadPerWarp_M, S::Vector_M>,
                      sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
                tuple<sequence<1, 2>, sequence<1, 2>>,
                tuple<sequence<1, 1>, sequence<2, 2>>,
                sequence<1, 1, 2, 2>,
                sequence<0, 3, 0, 3>>{});
    }

    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeGammaBlockTileDistribution()
    {
        using S = typename Problem::BlockShape;
        return make_static_tile_distribution(
            tile_distribution_encoding<
                sequence<S::WarpPerBlock_M, S::ThreadPerWarp_M>,
                tuple<sequence<S::Repeat_N, S::WarpPerBlock_N, S::ThreadPerWarp_N, S::Vector_N>>,
                tuple<sequence<0, 1>, sequence<0, 1>>,
                tuple<sequence<0, 1>, sequence<1, 2>>,
                sequence<1, 1>,
                sequence<0, 3>>{});
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2d()
    {
        using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
                                        typename Problem::ComputeDataType,
                                        typename Problem::BlockShape>;
        return BlockReduce2d<P_>{};
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dSync()
    {
        using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
                                        typename Problem::ComputeDataType,
                                        typename Problem::BlockShape>;
        return BlockReduce2dSync<P_>{};
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockReduce2dCrossWarpSync()
    {
        using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
                                        typename Problem::ComputeDataType,
                                        typename Problem::BlockShape>;
        return BlockReduce2dCrossWarpSync<P_>{};
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        if constexpr(Problem::kNeedCrossWarpSync)
        {
            using P_ = BlockReduce2dProblem<typename Problem::ComputeDataType,
                                            typename Problem::ComputeDataType,
                                            typename Problem::BlockShape>;
            using block_reduce2d = BlockReduce2d<P_>;

            using x_block_tile =
                decltype(make_static_distributed_tensor<typename Problem::ComputeDataType>(
                    MakeABXBlockTileDistribution<Problem>()));
            using y_block_tile = decltype(block_reduce2d::template MakeYBlockTile<x_block_tile>());

            return GetBlockReduce2dCrossWarpSync<Problem>().template GetSmemSize<y_block_tile>();
        }
        else
        {
            return 1; // zero size arrays are an extension
        }
    }
};
} // namespace ck_tile
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_one_pass.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

template <typename Problem_, typename Policy_ = AddRmsnorm2dRdquantFwdPipelineDefaultPolicy>
struct AddRmsnorm2dRdquantFwdPipelineOnePass
{
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy  = ck_tile::remove_cvref_t<Policy_>;

    using ADataType       = ck_tile::remove_cvref_t<typename Problem::ADataType>;
    using BDataType       = ck_tile::remove_cvref_t<typename Problem::BDataType>;
    using GammaDataType   = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
    using YScaleDataType  = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
    using QYDataType      = ck_tile::remove_cvref_t<typename Problem::QYDataType>;

    static constexpr bool kHasGamma          = !std::is_same_v<GammaDataType, ck_tile::null_type>;
    static constexpr bool kSaveX             = Problem::kSaveX;
    static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;

    static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
    static constexpr bool kPadN = Problem::kPadN;

    static constexpr const char* name = []() {
        if constexpr(kNeedCrossWarpSync)
            return "bpr_op"; // block per row
        else
            return "wpr_op"; // warp per row
    }();

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <typename AWindow,
              typename BWindow,
              typename GammaWindow,
              typename XWindow,
              typename YScaleWindow,
              typename QYWindow>
    CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
                                   const BWindow& b_window_,
                                   const GammaWindow& gamma_window_,
                                   XWindow& x_window,
                                   YScaleWindow& yscale_window,
                                   QYWindow& qy_window,
                                   ComputeDataType epsilon,
                                   ck_tile::index_t row_size,
                                   void* smem) const
    {
        const auto a_window =
            make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
        const auto b_window =
            make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
        const auto gamma_window = make_tile_window(
            gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());

        auto reduce_square_sum_func = ReduceOp::SquareAdd{};
        auto reduce_sum_func        = ReduceOp::Add{};
        auto reduce_absmax_func     = ReduceOp::AbsMax{};
        auto reduce_max_func        = ReduceOp::Max{};
        auto block_reduce2d         = Policy::template GetBlockReduce2d<Problem>();
        auto block_reduce2d_sync    = Policy::template GetBlockReduce2dSync<Problem>();
        auto block_reduce2d_cross_warp_sync =
            Policy::template GetBlockReduce2dCrossWarpSync<Problem>();

        const auto a     = load_tile(a_window);
        const auto b     = load_tile(b_window);
        const auto gamma = load_tile(gamma_window);

        auto x = tile_elementwise_in(
            [&](const auto& a_, const auto& b_) {
                return type_convert<ComputeDataType>(a_) + type_convert<ComputeDataType>(b_);
            },
            a,
            b);
        if constexpr(kSaveX)
            store_tile(x_window, cast_tile<XDataType>(x));

        // compute mean square, each-thread->cross-lane->cross-warp
        auto square_sum = block_reduce2d(
            x, reduce_square_sum_func.GetIdentityValue<ComputeDataType>(), reduce_square_sum_func);
        block_reduce2d_sync(square_sum, reduce_sum_func);
        block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);

        auto inv_rms = tile_elementwise_in(
            [&](const auto& v_) {
                return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
            },
            square_sum);

        // rmsnorm computation
        auto y = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
        sweep_tile(y, [&, inv_rms_ = inv_rms](auto idx) {
            constexpr auto i_idx = make_tuple(idx[number<0>{}]);
            constexpr auto j_idx = make_tuple(idx[number<1>{}]);

            const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);

            const auto x_ = type_convert<ComputeDataType>(x[idx]);
            auto y_       = x_ * inv_rms_[i_idx] * gamma_;

            y(idx) = type_convert<ComputeDataType>(y_);
        });

        // compute absmax, each-thread->cross-lane->cross-warp
        auto absmax = block_reduce2d(
            y, reduce_absmax_func.GetIdentityValue<ComputeDataType>(), reduce_absmax_func);
        block_reduce2d_sync(absmax, reduce_max_func);
        block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);

        // ex: yscale = absmax / 127 if int8
        auto yscale = tile_elementwise_in(
            [&](const auto& v_) {
                return v_ / type_convert<ComputeDataType>(numeric<QYDataType>::max());
            },
            absmax);
        store_tile(yscale_window, cast_tile<YScaleDataType>(yscale));

        // quantize y to qy
        auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
        sweep_tile(qy, [&, yscale_ = yscale](auto idx) {
            constexpr auto i_idx = make_tuple(idx[number<0>{}]);
            auto qy_             = y[idx] / yscale_[i_idx];
            qy(idx)              = saturates<QYDataType>{}(qy_);
        });
        store_tile(qy_window, qy);
    }
};
} // namespace ck_tile
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_problem.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core/utility/type_traits.hpp"

namespace ck_tile {

// X = A + B, Y = Rmsnorm2d(X), QY = RowwiseDynamicQuant(Y) = SaturateCast(Y / YScale)
template <typename ADataType_,
          typename BDataType_,
          typename GammaDataType_,
          typename ComputeDataType_,
          typename XDataType_,
          typename YScaleDataType_,
          typename QYDataType_,
          typename BlockShape_,
          bool kPadN_,
          bool kSaveX_,
          bool kThreePass_>
struct AddRmsnorm2dRdquantFwdPipelineProblem
{
    using ADataType       = remove_cvref_t<ADataType_>;
    using BDataType       = remove_cvref_t<BDataType_>;
    using GammaDataType   = remove_cvref_t<GammaDataType_>;
    using ComputeDataType = remove_cvref_t<ComputeDataType_>;
    using XDataType       = remove_cvref_t<XDataType_>;
    using YScaleDataType  = remove_cvref_t<YScaleDataType_>;
    using QYDataType      = remove_cvref_t<QYDataType_>;
    using BlockShape      = remove_cvref_t<BlockShape_>;

    static constexpr bool kNeedCrossLaneSync = BlockShape::ThreadPerWarp_N > 1;
    static constexpr bool kNeedCrossWarpSync = BlockShape::WarpPerBlock_N > 1;

    static constexpr bool kPadN      = kPadN_;
    static constexpr bool kSaveX     = kSaveX_;
    static constexpr bool kThreePass = kThreePass_;
};
} // namespace ck_tile
include/ck_tile/ops/add_rmsnorm2d_rdquant/pipeline/add_rmsnorm2d_rdquant_fwd_pipeline_three_pass.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/rmsnorm2d/pipeline/rmsnorm2d_fwd_pipeline_default_policy.hpp"
#include <string>
#include <type_traits>

namespace ck_tile {

template <typename Problem_, typename Policy_ = AddRmsnorm2dRdquantFwdPipelineDefaultPolicy>
struct AddRmsnorm2dRdquantFwdPipelineThreePass
{
    using Problem = ck_tile::remove_cvref_t<Problem_>;
    using Policy  = ck_tile::remove_cvref_t<Policy_>;

    using ADataType       = ck_tile::remove_cvref_t<typename Problem::ADataType>;
    using BDataType       = ck_tile::remove_cvref_t<typename Problem::BDataType>;
    using GammaDataType   = ck_tile::remove_cvref_t<typename Problem::GammaDataType>;
    using ComputeDataType = ck_tile::remove_cvref_t<typename Problem::ComputeDataType>;
    using XDataType       = ck_tile::remove_cvref_t<typename Problem::XDataType>;
    using YScaleDataType  = ck_tile::remove_cvref_t<typename Problem::YScaleDataType>;
    using QYDataType      = ck_tile::remove_cvref_t<typename Problem::QYDataType>;

    static constexpr bool kHasGamma          = !std::is_same_v<GammaDataType, ck_tile::null_type>;
    static constexpr bool kSaveX             = Problem::kSaveX;
    static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;

    static constexpr bool kPadM = false; // TODO - BlockAddRmsnorm2dRdquantFwdProblem::kPadM
    static constexpr bool kPadN = Problem::kPadN;

    static constexpr const char* name = []() {
        if constexpr(kNeedCrossWarpSync)
            return "bpr_tp"; // block per row
        else
            return "wpr_tp"; // warp per row
    }();

    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
    {
        return Policy::template GetSmemSize<Problem>();
    }

    template <typename AWindow,
              typename BWindow,
              typename GammaWindow,
              typename XWindow,
              typename YScaleWindow,
              typename QYWindow>
    CK_TILE_DEVICE auto operator()(const AWindow& a_window_,
                                   const BWindow& b_window_,
                                   const GammaWindow& gamma_window_,
                                   XWindow& x_window_,
                                   YScaleWindow& yscale_window,
                                   QYWindow& qy_window,
                                   ComputeDataType epsilon,
                                   ck_tile::index_t row_size,
                                   void* smem) const
    {
        auto a_window =
            make_tile_window(a_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
        auto b_window =
            make_tile_window(b_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
        auto x_window = [&]() {
            if constexpr(kSaveX)
                return make_tile_window(
                    x_window_, Policy::template MakeABXBlockTileDistribution<Problem>());
            else
                return x_window_;
        }();
        auto gamma_window = make_tile_window(
            gamma_window_, Policy::template MakeGammaBlockTileDistribution<Problem>());

        auto reduce_square_sum_func = ReduceOp::SquareAdd{};
        auto reduce_sum_func        = ReduceOp::Add{};
        auto reduce_absmax_func     = ReduceOp::AbsMax{};
        auto reduce_max_func        = ReduceOp::Max{};
        auto block_reduce2d         = Policy::template GetBlockReduce2d<Problem>();
        auto block_reduce2d_sync    = Policy::template GetBlockReduce2dSync<Problem>();
        auto block_reduce2d_cross_warp_sync =
            Policy::template GetBlockReduce2dCrossWarpSync<Problem>();

        static constexpr index_t Block_N = Problem::BlockShape::Block_N;
        index_t num_n_tile_iteration =
            __builtin_amdgcn_readfirstlane(integer_divide_ceil(row_size, Block_N));

        using XTensorType = decltype(cast_tile<ComputeDataType>(load_tile(a_window)));
        auto square_sum   = block_reduce2d.template MakeYBlockTile<XTensorType>();
        set_tile(square_sum, reduce_square_sum_func.GetIdentityValue<ComputeDataType>());

        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
        {
            const auto a = load_tile(a_window);
            const auto b = load_tile(b_window);
            auto x       = tile_elementwise_in(
                [&](const auto& a_, const auto& b_) {
                    return type_convert<ComputeDataType>(a_) + type_convert<ComputeDataType>(b_);
                },
                a,
                b);
            if constexpr(kSaveX)
                store_tile(x_window, cast_tile<XDataType>(x));
            block_reduce2d(x, square_sum, reduce_square_sum_func);
            move_tile_window(x_window, {0, Block_N});
            move_tile_window(a_window, {0, Block_N});
            move_tile_window(b_window, {0, Block_N});
        }

        block_reduce2d_sync(square_sum, reduce_sum_func);
        block_reduce2d_cross_warp_sync(square_sum, smem, reduce_sum_func);

        auto inv_rms = tile_elementwise_in(
            [&](const auto& v_) {
                return type_convert<ComputeDataType>(1.0f) / (sqrt(v_ / row_size + epsilon));
            },
            square_sum);

        // reverse read x to reuse cache
        ck_tile::index_t stride_to_right_most_window =
            row_size % Block_N == 0 ? row_size - Block_N : row_size - row_size % Block_N;

        if constexpr(kSaveX)
            move_tile_window(x_window, {0, -Block_N});
        else
        {
            move_tile_window(a_window, {0, -Block_N});
            move_tile_window(b_window, {0, -Block_N});
        }
        move_tile_window(gamma_window, {stride_to_right_most_window});

        using YTensorType = XTensorType;
        auto absmax       = block_reduce2d.template MakeYBlockTile<YTensorType>();
        set_tile(absmax, reduce_absmax_func.GetIdentityValue<ComputeDataType>());

        // rmsnorm computation + absmax(threadwise reduce)
        if constexpr(kSaveX)
            __syncthreads();
        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
        {
            auto x = [&]() {
                if constexpr(kSaveX)
                {
                    return load_tile(x_window);
                }
                else
                {
                    const auto a = load_tile(a_window);
                    const auto b = load_tile(b_window);
                    return tile_elementwise_in(
                        [&](const auto& a_, const auto& b_) {
                            return type_convert<ComputeDataType>(a_) +
                                   type_convert<ComputeDataType>(b_);
                        },
                        a,
                        b);
                }
            }();
            auto gamma = load_tile(gamma_window);

            auto y = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
            sweep_tile(y, [&](auto idx) {
                constexpr auto i_idx = make_tuple(idx[number<0>{}]);
                constexpr auto j_idx = make_tuple(idx[number<1>{}]);

                const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);

                const auto x_ = type_convert<ComputeDataType>(x[idx]);
                auto y_       = x_ * inv_rms[i_idx] * gamma_;

                y(idx) = type_convert<ComputeDataType>(y_);
            });

            block_reduce2d(y, absmax, reduce_absmax_func);

            if constexpr(kSaveX)
                move_tile_window(x_window, {0, -Block_N});
            else
            {
                move_tile_window(a_window, {0, -Block_N});
                move_tile_window(b_window, {0, -Block_N});
            }
            move_tile_window(gamma_window, {-Block_N});
        }

        // compute absmax, cross-lane->cross-warp
        block_reduce2d_sync(absmax, reduce_max_func);
        block_reduce2d_cross_warp_sync(absmax, smem, reduce_max_func);

        // ex: yscale = absmax / 127 if int8
        auto yscale = tile_elementwise_in(
            [&](const auto& v_) {
                return v_ / type_convert<ComputeDataType>(numeric<QYDataType>::max());
            },
            absmax);

        store_tile(yscale_window, cast_tile<YScaleDataType>(yscale));

        // quantize y to qy
        // recompute rmsnorm, try to save y in the future
        if constexpr(kSaveX)
            move_tile_window(x_window, {0, Block_N});
        else
        {
            move_tile_window(a_window, {0, Block_N});
            move_tile_window(b_window, {0, Block_N});
        }
        move_tile_window(gamma_window, {Block_N});

        for(int iN = __builtin_amdgcn_readfirstlane(0); iN < num_n_tile_iteration; ++iN)
        {
            auto x = [&]() {
                if constexpr(kSaveX)
                {
                    return load_tile(x_window);
                }
                else
                {
                    const auto a = load_tile(a_window);
                    const auto b = load_tile(b_window);
                    return tile_elementwise_in(
                        [&](const auto& a_, const auto& b_) {
                            return type_convert<ComputeDataType>(a_) +
                                   type_convert<ComputeDataType>(b_);
                        },
                        a,
                        b);
                }
            }();
            auto gamma = load_tile(gamma_window);

            auto y  = make_static_distributed_tensor<ComputeDataType>(x.get_tile_distribution());
            auto qy = make_static_distributed_tensor<QYDataType>(y.get_tile_distribution());
            sweep_tile(y, [&](auto idx) {
                constexpr auto i_idx = make_tuple(idx[number<0>{}]);
                constexpr auto j_idx = make_tuple(idx[number<1>{}]);

                const auto gamma_ = type_convert<ComputeDataType>(gamma[j_idx]);

                const auto x_ = type_convert<ComputeDataType>(x[idx]);
                auto y_       = x_ * inv_rms[i_idx] * gamma_;
                auto qy_      = y_ / yscale[i_idx];
                qy(idx)       = saturates<QYDataType>{}(qy_);
            });
            store_tile(qy_window, qy);

            if constexpr(kSaveX)
                move_tile_window(x_window, {0, Block_N});
            else
            {
                move_tile_window(a_window, {0, Block_N});
                move_tile_window(b_window, {0, Block_N});
            }
            move_tile_window(gamma_window, {Block_N});
            move_tile_window(qy_window, {0, Block_N});
        }
    }
};
} // namespace ck_tile
include/ck_tile/ops/common.hpp

@@ -3,4 +3,5 @@
 #pragma once

+#include "ck_tile/ops/common/generic_2d_block_shape.hpp"
 #include "ck_tile/ops/common/tensor_layout.hpp"
include/ck_tile/ops/common/generic_2d_block_shape.hpp  (new file)

// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

namespace ck_tile {

/*
// clang-format off
4-level descriptor: BlockTile -> WarpPerBlock -> WarpTile -> Vector

  Block_N (Warp_N * WarpPerBlock_N * Repeat_N)
  +<----------------------< Repeat_N(2) >--------------------->+
  |                                                            |
  +<-- <WarpPerBlock_N(2)> -->+
         Warp_N
  +--------------+--------------+--------------+--------------+----+----------------+
  |    warp_0    |    warp_1    |              |              |  ^                  ^   Warp_M
  +--------------+--------------+              |              |  <WarpPerBlock_M(2)>|
  |    warp_2    |    warp_3    |              |              |  v                  |
  +--------------+--------------+--------------+--------------+----+             Block_M
  |              |              |              |              |                     |
  +              +              |              |              |                     |
  |              |              |              |              |                     v
  +--------------+--------------+--------------+--------------+                     +

  each Warp-tile (e.g. 16 thrd per row)
       Vector_N (contiguous pixels each thrd holds along N, or vector size)
  +-----------+-----------+-----------+-----------+-----------+
  |  thrd_0   |  thrd_1   |  thrd_2   |  thrd_3   | ...          Vector_M
  +-----------+-----------+-----------+-----------+-----------+
  |  thrd_16  |  thrd_17  |  thrd_18  |  thrd_19  | ...
  +-----------+-----------+-----------+-----------+-----------+
// clang-format on
*/
template <typename BlockTile_,    // block size, seq<M, N>
          typename WarpPerBlock_, // num warps along seq<M, N>
          typename WarpTile_,     // warp size, seq<M, N>
          typename Vector_,       // contiguous pixels(vector size) along seq<M, N>
          index_t BlockSize_ =
              warpSize * reduce_on_sequence(WarpPerBlock_{}, multiplies{}, number<1>{})>
struct Generic2dBlockShape
{
    // block size
    static constexpr index_t Block_M = BlockTile_::at(number<0>{});
    static constexpr index_t Block_N = BlockTile_::at(number<1>{});

    // num warps along seq<M, N>, within each block
    static constexpr index_t WarpPerBlock_M = WarpPerBlock_::at(number<0>{});
    static constexpr index_t WarpPerBlock_N = WarpPerBlock_::at(number<1>{});

    // warp size
    static constexpr index_t Warp_M = WarpTile_::at(number<0>{});
    static constexpr index_t Warp_N = WarpTile_::at(number<1>{});

    static_assert(Block_M % (WarpPerBlock_M * Warp_M) == 0);
    static_assert(Block_N % (WarpPerBlock_N * Warp_N) == 0);

    // repeat of each thread along seq<M, N>
    static constexpr index_t Repeat_M = Block_M / (WarpPerBlock_M * Warp_M);
    static constexpr index_t Repeat_N = Block_N / (WarpPerBlock_N * Warp_N);

    // vector size along seq<M, N>
    static constexpr index_t Vector_M = Vector_::at(number<0>{});
    static constexpr index_t Vector_N = Vector_::at(number<1>{});

    static_assert(Warp_M % Vector_M == 0);
    static_assert(Warp_N % Vector_N == 0);

    // num of threads along seq<M, N>, within each warp
    static constexpr index_t ThreadPerWarp_M = Warp_M / Vector_M;
    static constexpr index_t ThreadPerWarp_N = Warp_N / Vector_N;

    static constexpr index_t BlockSize = BlockSize_;
};
} // namespace ck_tile