Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
f3acd251
Unverified
Commit
f3acd251
authored
Sep 05, 2021
by
Chao Liu
Committed by
GitHub
Sep 05, 2021
Browse files
Add a version of Merge transform that use integerdivision and mod (#25)
* add Merg_v3_division_mod * refactor
parent
19613902
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
142 additions
and
5 deletions
+142
-5
composable_kernel/include/tensor_description/multi_index_transform.hpp
...rnel/include/tensor_description/multi_index_transform.hpp
+123
-0
composable_kernel/include/tensor_description/multi_index_transform_helper.hpp
...clude/tensor_description/multi_index_transform_helper.hpp
+19
-5
No files found.
composable_kernel/include/tensor_description/multi_index_transform.hpp
View file @
f3acd251
...
...
@@ -1327,6 +1327,129 @@ struct Merge_v2r2_magic_division
}
};
// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
// will be very bad
template
<
typename
LowLengths
>
struct
Merge_v3_division_mod
{
static
constexpr
index_t
NDimLow
=
LowLengths
::
Size
();
using
LowerIndex
=
MultiIndex
<
NDimLow
>
;
using
UpperIndex
=
MultiIndex
<
1
>
;
using
LowLengthsScan
=
decltype
(
container_reverse_exclusive_scan
(
LowLengths
{},
math
::
multiplies
{},
Number
<
1
>
{}));
using
UpLengths
=
decltype
(
make_tuple
(
container_reduce
(
LowLengths
{},
math
::
multiplies
{},
Number
<
1
>
{})));
LowLengths
low_lengths_
;
LowLengthsScan
low_lengths_scan_
;
UpLengths
up_lengths_
;
__host__
__device__
constexpr
Merge_v3_division_mod
()
=
default
;
__host__
__device__
constexpr
Merge_v3_division_mod
(
const
LowLengths
&
low_lengths
)
:
low_lengths_
{
low_lengths
},
low_lengths_scan_
{
container_reverse_exclusive_scan
(
low_lengths
,
math
::
multiplies
{},
Number
<
1
>
{})},
up_lengths_
{
make_tuple
(
container_reduce
(
low_lengths
,
math
::
multiplies
{},
Number
<
1
>
{}))}
{
static_assert
(
LowerIndex
::
Size
()
==
NDimLow
,
"wrong!"
);
}
__host__
__device__
static
constexpr
index_t
GetNumOfLowerDimension
()
{
return
NDimLow
;
}
__host__
__device__
static
constexpr
index_t
GetNumOfUpperDimension
()
{
return
1
;
}
__host__
__device__
constexpr
const
auto
&
GetUpperLengths
()
const
{
return
up_lengths_
;
}
template
<
typename
LowIdx
,
typename
UpIdx
>
__host__
__device__
constexpr
void
CalculateLowerIndex
(
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up
)
const
{
static_assert
(
LowIdx
::
Size
()
==
NDimLow
&&
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
index_t
tmp
=
idx_up
[
Number
<
0
>
{}];
// division and mod
static_for
<
0
,
NDimLow
-
1
,
1
>
{}([
&
](
auto
i
)
{
idx_low
(
i
)
=
tmp
/
this
->
low_lengths_scan_
[
i
];
tmp
%=
this
->
low_lengths_scan_
[
i
];
});
idx_low
(
Number
<
NDimLow
-
1
>
{})
=
tmp
;
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
,
index_t
Hack
>
__host__
__device__
void
UpdateLowerIndex
(
LowIdxDiff
&
idx_diff_low
,
const
UpIdxDiff
&
,
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up_new
,
Number
<
Hack
>
)
const
{
static_assert
(
LowIdxDiff
::
Size
()
==
NDimLow
&&
UpIdxDiff
::
Size
()
==
1
&&
LowIdx
::
Size
()
==
NDimLow
&&
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
INm1
=
Number
<
NDimLow
-
1
>
{};
index_t
tmp
=
idx_up_new
[
I0
];
static_for
<
0
,
NDimLow
-
1
,
1
>
{}([
&
](
auto
i
)
{
const
index_t
tmp2
=
idx_low
[
i
];
idx_low
(
i
)
=
tmp
/
this
->
low_lengths_scan_
[
i
];
idx_diff_low
(
i
)
=
idx_low
[
i
]
-
tmp2
;
tmp
%=
this
->
low_lengths_scan_
[
i
];
});
const
index_t
tmp2
=
idx_low
[
INm1
];
idx_low
(
INm1
)
=
tmp
;
idx_diff_low
(
INm1
)
=
idx_low
[
INm1
]
-
tmp2
;
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
false
;
}
__host__
__device__
static
constexpr
bool
IsValidUpperIndexAlwaysMappedToValidLowerIndex
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsKnownAtCompileTime
()
{
return
is_known_at_compile_time
<
LowLengths
>::
value
&&
is_known_at_compile_time
<
LowLengthsScan
>::
value
&&
is_known_at_compile_time
<
UpLengths
>::
value
;
}
template
<
typename
UpIdx
>
__host__
__device__
static
constexpr
bool
IsValidUpperIndexMappedToValidLowerIndex
(
const
UpIdx
&
/* idx_up */
)
{
return
true
;
}
__host__
__device__
void
Print
()
const
{
printf
(
"{"
);
printf
(
"Merge_v3_direct_division_mod, "
);
printf
(
"low_lengths_ "
);
print_multi_index
(
low_lengths_
);
printf
(
"low_lengths_scan_ "
);
print_multi_index
(
low_lengths_scan_
);
printf
(
"up_lengths_ "
);
print_multi_index
(
up_lengths_
);
printf
(
"}"
);
}
};
template
<
typename
UpLengths
,
bool
Use24BitIntegerCalculation
>
struct
UnMerge
{
...
...
composable_kernel/include/tensor_description/multi_index_transform_helper.hpp
View file @
f3acd251
...
...
@@ -52,22 +52,36 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng
template
<
typename
LowLengths
>
__host__
__device__
constexpr
auto
make_merge_transform
(
const
LowLengths
&
low_lengths
)
{
#if
!
CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return
Merge_v1_carry_check
<
LowLengths
>
{
low_lengths
}
;
#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return
make_merge_transform_v2_magic_division
(
low_lengths
)
;
#else
return
make_merge_transform_v1_carry_check
(
low_lengths
);
#endif
}
template
<
typename
LowLengths
>
__host__
__device__
constexpr
auto
make_merge_transform_v1_carry_check
(
const
LowLengths
&
low_lengths
)
{
return
Merge_v1_carry_check
<
LowLengths
>
{
low_lengths
};
}
template
<
typename
LowLengths
>
__host__
__device__
constexpr
auto
make_merge_transform_v2_magic_division
(
const
LowLengths
&
low_lengths
)
{
#if 1
return
Merge_v2_magic_division
<
LowLengths
>
{
low_lengths
};
#else
return
Merge_v2r2_magic_division
<
LowLengths
>
{
low_lengths
};
#endif
#endif
}
template
<
typename
LowLengths
>
__host__
__device__
constexpr
auto
make_merge_transform_v
2_magic
_division
(
const
LowLengths
&
low_lengths
)
make_merge_transform_v
3
_division
_mod
(
const
LowLengths
&
low_lengths
)
{
return
Merge_v
2_magic
_division
<
LowLengths
>
{
low_lengths
};
return
Merge_v
3
_division
_mod
<
LowLengths
>
{
low_lengths
};
}
template
<
typename
UpLengths
,
bool
Use24BitIntegerCalculation
=
false
>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment