Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
composable_kernel
Commits
1a43a538
Commit
1a43a538
authored
Aug 26, 2021
by
Chao Liu
Browse files
add Merg_v3_division_mod
parent
a7a758d8
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
140 additions
and
3 deletions
+140
-3
composable_kernel/include/tensor_description/multi_index_transform.hpp
...rnel/include/tensor_description/multi_index_transform.hpp
+123
-0
composable_kernel/include/tensor_description/multi_index_transform_helper.hpp
...clude/tensor_description/multi_index_transform_helper.hpp
+17
-3
No files found.
composable_kernel/include/tensor_description/multi_index_transform.hpp
View file @
1a43a538
...
@@ -1327,6 +1327,129 @@ struct Merge_v2r2_magic_division
...
@@ -1327,6 +1327,129 @@ struct Merge_v2r2_magic_division
}
}
};
};
// Implementation of "Merge" transformation primitive that uses division and mod. It is supposed to
// be used for low_lengths that are known at compile time and are power of 2, otherwise performance
// will be very bad
template
<
typename
LowLengths
>
struct
Merge_v3_division_mod
{
static
constexpr
index_t
NDimLow
=
LowLengths
::
Size
();
using
LowerIndex
=
MultiIndex
<
NDimLow
>
;
using
UpperIndex
=
MultiIndex
<
1
>
;
using
LowLengthsScan
=
decltype
(
container_reverse_exclusive_scan
(
LowLengths
{},
math
::
multiplies
{},
Number
<
1
>
{}));
using
UpLengths
=
decltype
(
make_tuple
(
container_reduce
(
LowLengths
{},
math
::
multiplies
{},
Number
<
1
>
{})));
LowLengths
low_lengths_
;
LowLengthsScan
low_lengths_scan_
;
UpLengths
up_lengths_
;
__host__
__device__
constexpr
Merge_v3_division_mod
()
=
default
;
__host__
__device__
constexpr
Merge_v3_division_mod
(
const
LowLengths
&
low_lengths
)
:
low_lengths_
{
low_lengths
},
low_lengths_scan_
{
container_reverse_exclusive_scan
(
low_lengths
,
math
::
multiplies
{},
Number
<
1
>
{})},
up_lengths_
{
make_tuple
(
container_reduce
(
low_lengths
,
math
::
multiplies
{},
Number
<
1
>
{}))}
{
static_assert
(
LowerIndex
::
Size
()
==
NDimLow
,
"wrong!"
);
}
__host__
__device__
static
constexpr
index_t
GetNumOfLowerDimension
()
{
return
NDimLow
;
}
__host__
__device__
static
constexpr
index_t
GetNumOfUpperDimension
()
{
return
1
;
}
__host__
__device__
constexpr
const
auto
&
GetUpperLengths
()
const
{
return
up_lengths_
;
}
template
<
typename
LowIdx
,
typename
UpIdx
>
__host__
__device__
constexpr
void
CalculateLowerIndex
(
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up
)
const
{
static_assert
(
LowIdx
::
Size
()
==
NDimLow
&&
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
index_t
tmp
=
idx_up
[
Number
<
0
>
{}];
// division and mod
static_for
<
0
,
NDimLow
-
1
,
1
>
{}([
&
](
auto
i
)
{
idx_low
(
i
)
=
tmp
/
this
->
low_lengths_scan_
[
i
];
tmp
%=
this
->
low_lengths_scan_
[
i
];
});
idx_low
(
Number
<
NDimLow
-
1
>
{})
=
tmp
;
}
template
<
typename
LowIdxDiff
,
typename
UpIdxDiff
,
typename
LowIdx
,
typename
UpIdx
,
index_t
Hack
>
__host__
__device__
void
UpdateLowerIndex
(
LowIdxDiff
&
idx_diff_low
,
const
UpIdxDiff
&
,
LowIdx
&
idx_low
,
const
UpIdx
&
idx_up_new
,
Number
<
Hack
>
)
const
{
static_assert
(
LowIdxDiff
::
Size
()
==
NDimLow
&&
UpIdxDiff
::
Size
()
==
1
&&
LowIdx
::
Size
()
==
NDimLow
&&
UpIdx
::
Size
()
==
1
,
"wrong! inconsistent # of dimension"
);
constexpr
auto
I0
=
Number
<
0
>
{};
constexpr
auto
INm1
=
Number
<
NDimLow
-
1
>
{};
index_t
tmp
=
idx_up_new
[
I0
];
static_for
<
0
,
NDimLow
-
1
,
1
>
{}([
&
](
auto
i
)
{
const
index_t
tmp2
=
idx_low
[
i
];
idx_low
(
i
)
=
tmp
/
this
->
low_lengths_scan_
[
i
];
idx_diff_low
(
i
)
=
idx_low
[
i
]
-
tmp2
;
tmp
%=
this
->
low_lengths_scan_
[
i
];
});
const
index_t
tmp2
=
idx_low
[
INm1
];
idx_low
(
INm1
)
=
tmp
;
idx_diff_low
(
INm1
)
=
idx_low
[
INm1
]
-
tmp2
;
}
__host__
__device__
static
constexpr
bool
IsLinearTransform
()
{
return
false
;
}
__host__
__device__
static
constexpr
bool
IsValidUpperIndexAlwaysMappedToValidLowerIndex
()
{
return
true
;
}
__host__
__device__
static
constexpr
bool
IsKnownAtCompileTime
()
{
return
is_known_at_compile_time
<
LowLengths
>::
value
&&
is_known_at_compile_time
<
LowLengthsScan
>::
value
&&
is_known_at_compile_time
<
UpLengths
>::
value
;
}
template
<
typename
UpIdx
>
__host__
__device__
static
constexpr
bool
IsValidUpperIndexMappedToValidLowerIndex
(
const
UpIdx
&
/* idx_up */
)
{
return
true
;
}
__host__
__device__
void
Print
()
const
{
printf
(
"{"
);
printf
(
"Merge_v3_direct_division_mod, "
);
printf
(
"low_lengths_ "
);
print_multi_index
(
low_lengths_
);
printf
(
"low_lengths_scan_ "
);
print_multi_index
(
low_lengths_scan_
);
printf
(
"up_lengths_ "
);
print_multi_index
(
up_lengths_
);
printf
(
"}"
);
}
};
template
<
typename
UpLengths
,
bool
Use24BitIntegerCalculation
>
template
<
typename
UpLengths
,
bool
Use24BitIntegerCalculation
>
struct
UnMerge
struct
UnMerge
{
{
...
...
composable_kernel/include/tensor_description/multi_index_transform_helper.hpp
View file @
1a43a538
...
@@ -52,17 +52,24 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng
...
@@ -52,17 +52,24 @@ __host__ __device__ constexpr auto make_embed_transform(const UpLengths& up_leng
template
<
typename
LowLengths
>
template
<
typename
LowLengths
>
__host__
__device__
constexpr
auto
make_merge_transform
(
const
LowLengths
&
low_lengths
)
__host__
__device__
constexpr
auto
make_merge_transform
(
const
LowLengths
&
low_lengths
)
{
{
#if !CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
#if CK_EXPERIMENTAL_MERGE_USE_MAGIC_DIVISION
return
Merge_v1_carry_check
<
LowLengths
>
{
low_lengths
};
#else
#if 1
#if 1
return
Merge_v2_magic_division
<
LowLengths
>
{
low_lengths
};
return
Merge_v2_magic_division
<
LowLengths
>
{
low_lengths
};
#else
#else
return
Merge_v2r2_magic_division
<
LowLengths
>
{
low_lengths
};
return
Merge_v2r2_magic_division
<
LowLengths
>
{
low_lengths
};
#endif
#endif
#else
return
Merge_v1_carry_check
<
LowLengths
>
{
low_lengths
};
#endif
#endif
}
}
template
<
typename
LowLengths
>
__host__
__device__
constexpr
auto
make_merge_transform_v1_carry_check
(
const
LowLengths
&
low_lengths
)
{
return
Merge_v1_carry_check
<
LowLengths
>
{
low_lengths
};
}
template
<
typename
LowLengths
>
template
<
typename
LowLengths
>
__host__
__device__
constexpr
auto
__host__
__device__
constexpr
auto
make_merge_transform_v2_magic_division
(
const
LowLengths
&
low_lengths
)
make_merge_transform_v2_magic_division
(
const
LowLengths
&
low_lengths
)
...
@@ -70,6 +77,13 @@ make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
...
@@ -70,6 +77,13 @@ make_merge_transform_v2_magic_division(const LowLengths& low_lengths)
return
Merge_v2_magic_division
<
LowLengths
>
{
low_lengths
};
return
Merge_v2_magic_division
<
LowLengths
>
{
low_lengths
};
}
}
template
<
typename
LowLengths
>
__host__
__device__
constexpr
auto
make_merge_transform_v3_division_mod
(
const
LowLengths
&
low_lengths
)
{
return
Merge_v3_division_mod
<
LowLengths
>
{
low_lengths
};
}
template
<
typename
UpLengths
,
bool
Use24BitIntegerCalculation
=
false
>
template
<
typename
UpLengths
,
bool
Use24BitIntegerCalculation
=
false
>
__host__
__device__
constexpr
auto
make_unmerge_transform
(
__host__
__device__
constexpr
auto
make_unmerge_transform
(
const
UpLengths
&
up_lengths
,
const
UpLengths
&
up_lengths
,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment