Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel_ROCM
Commits
e1396d87
"...git@developer.sourcefind.cn:OpenDAS/mmdetection3d.git" did not exist on "f5a5103a95c4578ec112421cb4dcc23f564bb328"
Commit
e1396d87
authored
Sep 17, 2024
by
carlushuang
Browse files
update SweepTile API
parent
6975cb8f
Changes
7
Hide whitespace changes
Inline
Side-by-side
Showing
7 changed files
with
310 additions
and
24 deletions
+310
-24
example/ck_tile/04_topk_softmax/topk_softmax_api.cpp
example/ck_tile/04_topk_softmax/topk_softmax_api.cpp
+2
-2
include/ck_tile/core/container/tuple.hpp
include/ck_tile/core/container/tuple.hpp
+20
-0
include/ck_tile/core/tensor/static_distributed_tensor.hpp
include/ck_tile/core/tensor/static_distributed_tensor.hpp
+14
-0
include/ck_tile/core/tensor/sweep_tile.hpp
include/ck_tile/core/tensor/sweep_tile.hpp
+263
-0
include/ck_tile/core/utility/functional_with_tuple.hpp
include/ck_tile/core/utility/functional_with_tuple.hpp
+2
-0
include/ck_tile/ops/fmha.hpp
include/ck_tile/ops/fmha.hpp
+2
-2
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp
+7
-20
No files found.
example/ck_tile/04_topk_softmax/topk_softmax_api.cpp
View file @
e1396d87
...
@@ -51,9 +51,9 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
...
@@ -51,9 +51,9 @@ float topk_softmax(topk_softmax_trait t, topk_softmax_kargs a, ck_tile::stream_c
TOPK_SOFTMAX_DISPATCH
(
192
)
TOPK_SOFTMAX_DISPATCH
(
192
)
}
}
#else
#else
if
(
t
.
experts
<=
1
6
)
if
(
t
.
experts
<=
1
28
)
{
{
TOPK_SOFTMAX_DISPATCH
(
1
6
)
TOPK_SOFTMAX_DISPATCH
(
1
28
)
}
}
#endif
#endif
}
}
...
...
include/ck_tile/core/container/tuple.hpp
View file @
e1396d87
...
@@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
...
@@ -488,6 +488,26 @@ CK_TILE_HOST_DEVICE constexpr auto transform_tuples(F f, const X& x, const Y& y,
f
,
x
,
y
,
z
,
typename
arithmetic_sequence_gen
<
0
,
X
::
size
(),
1
>::
type
{});
f
,
x
,
y
,
z
,
typename
arithmetic_sequence_gen
<
0
,
X
::
size
(),
1
>::
type
{});
}
}
namespace
detail
{
template
<
typename
F
,
typename
X
,
index_t
...
Is
>
CK_TILE_HOST_DEVICE
constexpr
auto
embed_tuples_impl
(
F
f
,
const
X
&
x
,
sequence
<
Is
...
>
)
{
return
concat_tuple
(
f
(
x
.
at
(
number
<
Is
>
{}))...);
}
}
// namespace detail
// make sure F return at least a tuple
// e.g. x : tuple<X, Y>, f will return tuple<Z, W>
// this function will return
template
<
typename
F
,
typename
X
>
CK_TILE_HOST_DEVICE
constexpr
auto
embed_tuples
(
F
f
,
const
X
&
x
)
{
return
detail
::
embed_tuples_impl
(
f
,
x
,
typename
arithmetic_sequence_gen
<
0
,
X
::
size
(),
1
>::
type
{});
}
// By default unroll to the flatten
// By default unroll to the flatten
template
<
index_t
Depth
=
0
,
index_t
MaxDepth
=
-
1
>
template
<
index_t
Depth
=
0
,
index_t
MaxDepth
=
-
1
>
CK_TILE_HOST_DEVICE
constexpr
auto
unroll_nested_tuple
(
const
tuple
<>&
t
)
CK_TILE_HOST_DEVICE
constexpr
auto
unroll_nested_tuple
(
const
tuple
<>&
t
)
...
...
include/ck_tile/core/tensor/static_distributed_tensor.hpp
View file @
e1396d87
...
@@ -187,6 +187,20 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
...
@@ -187,6 +187,20 @@ set_tile_if(static_distributed_tensor<DataType, StaticTileDistribution>& out_ten
});
});
}
}
// this function used inside span loop over
template
<
typename
YLengths
,
index_t
XUnpacks
>
CK_TILE_HOST_DEVICE
constexpr
auto
get_y_unpacks_from_x_unpacks
(
YLengths
,
number
<
XUnpacks
>
)
{
constexpr
auto
y_size
=
reduce_on_sequence
(
YLengths
{},
multiplies
{},
number
<
1
>
{});
constexpr
auto
y_packs
=
number
<
XUnpacks
>
{};
static_assert
(
y_size
%
y_packs
==
0
);
constexpr
auto
y_slice_size
=
y_size
/
y_packs
;
constexpr
auto
slice_info
=
slice_sequence
(
YLengths
{},
number
<
y_slice_size
>
{});
constexpr
auto
unpacks
=
slice_info
[
number
<
1
>
{}];
return
unpacks
;
}
namespace
detail
{
namespace
detail
{
// check if 2 static_distributed_tensor has same data type and size of element
// check if 2 static_distributed_tensor has same data type and size of element
...
...
include/ck_tile/core/tensor/sweep_tile.hpp
View file @
e1396d87
...
@@ -42,4 +42,267 @@ CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F& f, Unpacks =
...
@@ -42,4 +42,267 @@ CK_TILE_DEVICE void sweep_tile_uspan(TileDistributedSpan_, const F& f, Unpacks =
[
&
](
auto
...
dstr_idx_impl
)
{
f
(
detail
::
make_tile_distributed_index
(
dstr_idx_impl
)...);
});
[
&
](
auto
...
dstr_idx_impl
)
{
f
(
detail
::
make_tile_distributed_index
(
dstr_idx_impl
)...);
});
}
}
namespace
impl
{
template
<
typename
,
typename
,
typename
>
struct
sweep_tile_impl
;
template
<
typename
DistributedTensor
,
typename
UnpacksPerXDim
,
index_t
I
,
index_t
...
Is
>
struct
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
I
,
Is
...
>>
{
CK_TILE_HOST_DEVICE
constexpr
auto
get_y_unpacks
()
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
y_lengths
=
typename
decltype
(
spans
[
number
<
I
>
{}])
::
Impl
{};
constexpr
auto
x_unpacks
=
number
<
UnpacksPerXDim
{}.
at
(
number
<
I
>
{})
>
{};
constexpr
auto
y_unpacks
=
get_y_unpacks_from_x_unpacks
(
y_lengths
,
x_unpacks
);
return
y_unpacks
;
}
CK_TILE_HOST_DEVICE
constexpr
index_t
get_num_of_access
()
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
u
=
static_uford
<
typename
decltype
(
spans
[
number
<
I
>
{}])
::
Impl
,
decltype
(
get_y_unpacks
())
>
{};
return
u
.
get_num_of_access
()
*
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}
.
get_num_of_access
();
}
template
<
typename
F
,
typename
SpanIdx
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
const
F
&
f
,
const
SpanIdx
&
span_idx
)
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
sweep_tile_uspan
(
spans
[
number
<
I
>
{}],
[
&
](
auto
...
i_idx
)
{
const
auto
next_span_idx
=
embed_tuples
(
[
&
](
auto
si
)
{
return
make_tuple
(
concat_tuple
(
si
,
make_tuple
(
i_idx
))...);
},
span_idx
);
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}(
f
,
next_span_idx
);
},
get_y_unpacks
());
}
template
<
typename
F
,
typename
SpanIdx
,
index_t
i_access
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
const
F
&
f
,
const
SpanIdx
&
span_idx
,
number
<
i_access
>
)
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
u
=
static_uford
<
typename
decltype
(
spans
[
number
<
I
>
{}])
::
Impl
,
decltype
(
get_y_unpacks
())
>
{};
constexpr
auto
access_stride
=
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}
.
get_num_of_access
();
constexpr
auto
curr_i_access
=
number
<
i_access
/
access_stride
>
{};
constexpr
auto
next_i_access
=
number
<
i_access
%
access_stride
>
{};
u
(
[
&
](
auto
...
i_idx
)
{
const
auto
next_span_idx
=
embed_tuples
(
[
&
](
auto
si
)
{
return
make_tuple
(
concat_tuple
(
si
,
make_tuple
(
detail
::
make_tile_distributed_index
(
i_idx
)))...);
},
span_idx
);
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}(
f
,
next_span_idx
,
next_i_access
);
},
curr_i_access
);
}
};
template
<
typename
DistributedTensor
,
typename
UnpacksPerXDim
>
struct
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<>>
{
CK_TILE_HOST_DEVICE
constexpr
index_t
get_num_of_access
()
const
{
return
1
;
}
template
<
typename
F
,
typename
SpanIdx
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
const
F
&
f
,
const
SpanIdx
&
span_idx
)
const
{
unpack
(
f
,
span_idx
);
}
template
<
typename
F
,
typename
SpanIdx
,
index_t
i_access
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
const
F
&
f
,
const
SpanIdx
&
span_idx
,
number
<
i_access
>
)
const
{
unpack
(
f
,
span_idx
);
}
};
template
<
typename
,
typename
,
typename
>
struct
sweep_tile_impl_0
;
// TODO: support empty tuple to remove this "entry-point" like function
template
<
typename
DistributedTensor
,
typename
UnpacksPerXDim
,
index_t
I
,
index_t
...
Is
>
struct
sweep_tile_impl_0
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
I
,
Is
...
>>
{
CK_TILE_HOST_DEVICE
constexpr
auto
get_y_unpacks
()
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
y_lengths
=
typename
decltype
(
spans
[
number
<
I
>
{}])
::
Impl
{};
constexpr
auto
x_unpacks
=
number
<
UnpacksPerXDim
{}.
at
(
number
<
I
>
{})
>
{};
constexpr
auto
y_unpacks
=
get_y_unpacks_from_x_unpacks
(
y_lengths
,
x_unpacks
);
return
y_unpacks
;
}
CK_TILE_HOST_DEVICE
constexpr
index_t
get_num_of_access
()
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
u
=
static_uford
<
typename
decltype
(
spans
[
number
<
I
>
{}])
::
Impl
,
decltype
(
get_y_unpacks
())
>
{};
return
u
.
get_num_of_access
()
*
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}
.
get_num_of_access
();
}
template
<
typename
F
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
const
F
&
f
)
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
sweep_tile_uspan
(
spans
[
number
<
I
>
{}],
[
&
](
auto
...
i_idx
)
{
constexpr
auto
next_span_idx
=
make_tuple
(
make_tuple
(
i_idx
)...);
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}(
f
,
next_span_idx
);
},
get_y_unpacks
());
}
template
<
typename
F
,
index_t
i_access
>
CK_TILE_HOST_DEVICE
constexpr
void
operator
()(
const
F
&
f
,
number
<
i_access
>
)
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
u
=
static_uford
<
typename
decltype
(
spans
[
number
<
I
>
{}])
::
Impl
,
decltype
(
get_y_unpacks
())
>
{};
constexpr
auto
access_stride
=
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}
.
get_num_of_access
();
constexpr
auto
curr_i_access
=
number
<
i_access
/
access_stride
>
{};
constexpr
auto
next_i_access
=
number
<
i_access
%
access_stride
>
{};
u
(
[
&
](
auto
...
i_idx
)
{
constexpr
auto
next_span_idx
=
make_tuple
(
make_tuple
(
detail
::
make_tile_distributed_index
(
i_idx
))...);
sweep_tile_impl
<
DistributedTensor
,
UnpacksPerXDim
,
sequence
<
Is
...
>>
{}(
f
,
next_span_idx
,
next_i_access
);
},
curr_i_access
);
}
};
}
// namespace impl
/*
* Enhanced sweep-tile utility, can control unpacks along each X-dim
* the lambda function argument is the distributed-idx, which can directly
* plugged into the distributed tensor as setter/getter
*
* e.g. below function, y with the type DistributedTensor, r is row scale
*
* // sweep tile 1 by 1
* sweep_tile<DistributedTensor>([&](auto idx) {
* constexpr auto row_id = make_tuple(idx[number<0>{}]);
* y(idx) = y(idx) * r(row_id);
* });
*
* // sweep tile with 2 pixel from last dim each function call
* sweep_tile<DistributedTensor>(
* [&](auto idx_0, auto idx_1) {
* constexpr auto row_id = make_tuple(idx_0[number<0>{}]);
* y(idx_0) = y(idx_0) * r(row_id);
* y(idx_1) = y(idx_1) * r(row_id);
* },
* sequence<1, 2>{});
*
* // sweep tile with 2x2 pixel each function call
* sweep_tile<DistributedTensor>(
* [&](auto idx_00, auto idx_01, auto idx_10, auto idx_11) {
* constexpr auto row_id0 = make_tuple(idx_00[number<0>{}]);
* constexpr auto row_id1 = make_tuple(idx_10[number<0>{}]);
* y(idx_00) = y(idx_00) * r(row_id0);
* y(idx_01) = y(idx_01) * r(row_id0);
* y(idx_10) = y(idx_10) * r(row_id1);
* y(idx_11) = y(idx_11) * r(row_id1);
* },
* sequence<2, 2>{});
*
* TODO: do we need constexpr? lambda function could be non-constexpr
*/
template
<
typename
DistributedTensor
,
typename
F
,
typename
UnpacksPerXDim
=
typename
uniform_sequence_gen
<
DistributedTensor
::
get_num_of_dimension
(),
1
>
::
type
>
CK_TILE_HOST_DEVICE
constexpr
void
sweep_tile
(
const
F
&
f
,
UnpacksPerXDim
=
{})
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
impl
::
sweep_tile_impl_0
<
DistributedTensor
,
UnpacksPerXDim
,
typename
arithmetic_sequence_gen
<
0
,
spans
.
size
(),
1
>::
type
>
{}(
f
);
}
template
<
typename
DistributedTensor
,
typename
F
,
typename
UnpacksPerXDim
=
typename
uniform_sequence_gen
<
DistributedTensor
::
get_num_of_dimension
(),
1
>
::
type
>
CK_TILE_HOST_DEVICE
constexpr
void
sweep_tile
(
const
DistributedTensor
&
,
const
F
&
f
,
UnpacksPerXDim
=
{})
{
sweep_tile
<
DistributedTensor
,
F
,
UnpacksPerXDim
>
(
f
,
UnpacksPerXDim
{});
}
/*
* construct a sweep tile instance, which support issue the lambda one by one
* Note that this struct will hold the lambda functor, but will not hold the distributed tensor
* the functionality is the same as sweep_tile()
*/
template
<
typename
DistributedTensor_
,
typename
F_
,
typename
UnpacksPerXDim_
=
typename
uniform_sequence_gen
<
DistributedTensor_
::
get_num_of_dimension
(),
1
>
::
type
>
struct
SweepTile
{
using
DistributedTensor
=
remove_cvref_t
<
DistributedTensor_
>
;
using
F
=
remove_cvref_t
<
F_
>
;
using
UnpacksPerXDim
=
remove_cvref_t
<
UnpacksPerXDim_
>
;
CK_TILE_HOST_DEVICE
SweepTile
(
const
F
&
f_
,
UnpacksPerXDim
=
{})
:
f
(
f_
)
{}
CK_TILE_HOST_DEVICE
SweepTile
(
const
DistributedTensor
&
,
const
F
&
f_
,
UnpacksPerXDim
=
{})
:
f
(
f_
)
{
}
CK_TILE_HOST_DEVICE
static
constexpr
index_t
get_num_of_access
()
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
tmp
=
impl
::
sweep_tile_impl_0
<
DistributedTensor
,
UnpacksPerXDim
,
typename
arithmetic_sequence_gen
<
0
,
spans
.
size
(),
1
>::
type
>
{};
return
tmp
.
get_num_of_access
();
}
CK_TILE_HOST_DEVICE
void
operator
()()
const
{
sweep_tile
<
DistributedTensor
>
(
f
,
UnpacksPerXDim
{});
}
template
<
index_t
i_access
>
CK_TILE_HOST_DEVICE
void
operator
()(
number
<
i_access
>
)
const
{
constexpr
auto
spans
=
DistributedTensor
::
get_distributed_spans
();
impl
::
sweep_tile_impl_0
<
DistributedTensor
,
UnpacksPerXDim
,
typename
arithmetic_sequence_gen
<
0
,
spans
.
size
(),
1
>::
type
>
{}(
f
,
number
<
i_access
>
{});
}
F
f
;
};
// partial deduction is not allowed
// template <typename T, typename F, typename U>
// CK_TILE_HOST_DEVICE_EXTERN SweepTile(const F&, U = {})->SweepTile<T, F, U>;
// deduction guide
template
<
typename
T
,
typename
F
,
typename
U
=
typename
uniform_sequence_gen
<
T
::
get_num_of_dimension
(),
1
>
::
type
>
CK_TILE_HOST_DEVICE_EXTERN
SweepTile
(
const
T
&
,
const
F
&
,
U
=
{})
->
SweepTile
<
T
,
F
,
U
>
;
}
// namespace ck_tile
}
// namespace ck_tile
include/ck_tile/core/utility/functional_with_tuple.hpp
View file @
e1396d87
...
@@ -105,6 +105,8 @@ struct static_uford_one_shot_impl<sequence<>, sequence<>, Orders>
...
@@ -105,6 +105,8 @@ struct static_uford_one_shot_impl<sequence<>, sequence<>, Orders>
}
// namespace detail
}
// namespace detail
// TODO: we may unify static_ford/static_uford in the future
//
// loop over nd space(sequence) with packs
// loop over nd space(sequence) with packs
// you must make sure the function passed in has same number of argument
// you must make sure the function passed in has same number of argument
//
//
...
...
include/ck_tile/ops/fmha.hpp
View file @
e1396d87
...
@@ -33,8 +33,8 @@
...
@@ -33,8 +33,8 @@
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp"
//
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp"
//
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_async_ex_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qr_ks_vs_async_default_policy.hpp"
...
...
include/ck_tile/ops/softmax/block/block_softmax_2d.hpp
View file @
e1396d87
...
@@ -46,15 +46,9 @@ struct BlockSoftmax2D
...
@@ -46,15 +46,9 @@ struct BlockSoftmax2D
#else
#else
auto
row_max
=
reduce_row_max
(
f_max
);
auto
row_max
=
reduce_row_max
(
f_max
);
#endif
#endif
// compute elementwise softmax
sweep_tile
<
DistributedTensor
>
([
&
](
auto
idx
)
{
constexpr
auto
span_2d
=
DistributedTensor
::
get_distributed_spans
();
constexpr
auto
row_id
=
make_tuple
(
idx
[
number
<
0
>
{}]);
y
(
idx
)
=
exp
(
x
[
idx
]
-
row_max
[
row_id
]);
sweep_tile_span
(
span_2d
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
sweep_tile_span
(
span_2d
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
y
(
i_j_idx
)
=
exp
(
x
[
i_j_idx
]
-
row_max
(
i_idx
));
});
});
});
// compute row sum
// compute row sum
...
@@ -66,19 +60,12 @@ struct BlockSoftmax2D
...
@@ -66,19 +60,12 @@ struct BlockSoftmax2D
#endif
#endif
// reciprocal
// reciprocal
auto
r
=
make_static_distributed_tensor
<
DataType
>
(
row_sum
.
get_tile_distribution
());
auto
r
=
make_static_distributed_tensor
<
DataType
>
(
row_sum
.
get_tile_distribution
());
constexpr
auto
span_1d
=
decltype
(
r
)
::
get_distributed_spans
();
sweep_tile
(
row_sum
,
[
&
](
auto
idx
)
{
r
(
idx
)
=
DataType
{
1
}
/
row_sum
(
idx
);
});
sweep_tile_span
(
span_1d
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
r
(
i_idx
)
=
DataType
{
1
}
/
row_sum
(
i_idx
);
});
// scale
// scale
sweep_tile_span
(
span_2d
[
number
<
0
>
{}],
[
&
](
auto
idx0
)
{
sweep_tile
<
DistributedTensor
>
([
&
](
auto
idx
)
{
constexpr
auto
i_idx
=
make_tuple
(
idx0
);
constexpr
auto
row_id
=
make_tuple
(
idx
[
number
<
0
>
{}]);
sweep_tile_span
(
span_2d
[
number
<
1
>
{}],
[
&
](
auto
idx1
)
{
y
(
idx
)
=
y
(
idx
)
*
r
(
row_id
);
constexpr
auto
i_j_idx
=
make_tuple
(
idx0
,
idx1
);
y
(
i_j_idx
)
=
y
(
i_j_idx
)
*
r
(
i_idx
);
});
});
});
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment