gaoqiong / composable_kernel · Commits

Commit 1d6022b1 authored Dec 16, 2020 by Jing Zhang

vector type output

parent 9a54fbd8

Showing 3 changed files with 56 additions and 21 deletions (+56 -21)
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp          +17 -17
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_v2.hpp  +32 -0
composable_kernel/include/utility/float_type.amd.hpp.in                                 +7 -4
composable_kernel/include/tensor_operation/gridwise_gemm_xdlops_fp16_bfp16.hpp
View file @ 1d6022b1

@@ -295,13 +295,11 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
         constexpr index_t BlkSize = blockwise_gemm.GetBlkSize();
         constexpr index_t NumBlks = blockwise_gemm.GetNumBlks();

         // force unrolling the output loop to get ride of scratches
-#pragma unroll
-        for(index_t i = 0; i < NumBlks; ++i)
-        {
+        static_for<0, NumBlks, 1>{}([&](auto blk_id) {
             // calculate origin of thread output tensor on global memory
             // blockwise GEMM c matrix starting index
-            const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(i);
+            const auto c_thread_mtx_on_block = blockwise_gemm.GetBeginOfThreadMatrixC(blk_id);

             const index_t m_thread_data_on_global =
                 m_block_data_on_global + c_thread_mtx_on_block.row;

@@ -309,24 +307,26 @@ struct GridwiseBatchGemmXdlops_gkmkpack_gknkpack_gmn_v2
             const index_t n_thread_data_on_global =
                 n_block_data_on_global + c_thread_mtx_on_block.col;

-            ThreadwiseGenericTensorSliceCopy_v4r2<decltype(c_g_m0_m1_m2_n_thread_desc),
+            ThreadwiseGenericTensorSliceCopy_v5<decltype(c_g_m0_m1_m2_n_thread_desc),
                                                   decltype(c_g_m0_m1_m2_n_global_desc),
                                                   CThreadCopySliceLengths,
                                                   arithmetic_sequence_gen<0, 5, 1>::type,
-                                                  4,
-                                                  1,
-                                                  1,
-                                                  AddressSpace::Vgpr,
-                                                  AddressSpace::Global,
-                                                  CGlobalMemoryOp>(
+                                                  arithmetic_sequence_gen<0, 5, 1>::type,
+                                                  4,
+                                                  4,
+                                                  1,
+                                                  1,
+                                                  AddressSpace::Vgpr,
+                                                  AddressSpace::Global,
+                                                  CGlobalMemoryOp>(
                 make_multi_index(0, 0, 0, 0, 0),
                 make_multi_index(g_block_data_on_global,
                                  m_thread_data_on_global / (M2 * M1),
                                  m_thread_data_on_global % (M2 * M1) / M2,
                                  m_thread_data_on_global % M2,
                                  n_thread_data_on_global))
-                .Run(c_thread_vec.n + i * BlkSize, p_c_global);
-        }
+                .Run(c_thread_vec.At(Number<16>{})[Number<blk_id>{}], p_c_global);
+        });
         }
     }
 };
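The point of switching from the unrolled runtime loop to static_for is that blk_id becomes a compile-time constant (a Number<>), so the .Run call can pull each output block out of the accumulator union with Number<blk_id>{} instead of doing pointer arithmetic on a plain float array; a runtime index into per-thread register storage is what the "force unrolling the output loop to get ride of scratches" comment is fighting. Below is a minimal standalone sketch of that idea in plain C++17. The static_for_n helper and the std::tuple stand-in for the register buffer are illustrative only; they are not CK's static_for (which takes begin/end/increment template parameters) or its StaticallyIndexedArray.

#include <cstdio>
#include <tuple>
#include <utility>

// Minimal compile-time loop: calls f(std::integral_constant<int, I>{}) for I = 0..N-1.
// Each call is stamped out separately, so the index is a constant in every body.
template <typename F, int... Is>
void static_for_n_impl(F&& f, std::integer_sequence<int, Is...>)
{
    (f(std::integral_constant<int, Is>{}), ...);
}

template <int N, typename F>
void static_for_n(F&& f)
{
    static_for_n_impl(f, std::make_integer_sequence<int, N>{});
}

int main()
{
    // Stand-in for the per-thread accumulator buffer, split into four blocks.
    std::tuple<float, float, float, float> c_blocks{1.f, 2.f, 3.f, 4.f};

    // blk_id carries its value in its type, so std::get<...> is resolved at
    // compile time; no runtime-indexed array (and hence no scratch) is needed.
    static_for_n<4>([&](auto blk_id) {
        constexpr int id = decltype(blk_id)::value;
        std::printf("block %d -> %g\n", id, double(std::get<id>(c_blocks)));
    });
}

In the commit itself this role is played by static_for<0, NumBlks, 1>{} together with c_thread_vec.At(Number<16>{})[Number<blk_id>{}], which selects one 16-wide sub-vector of the accumulator union per block.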
composable_kernel/include/tensor_operation/threadwise_generic_tensor_slice_copy_v2.hpp
View file @ 1d6022b1

@@ -236,6 +236,38 @@ struct ThreadwiseGenericTensorSliceCopy_v5
         });
     }

+    template <typename SrcData, typename DstData>
+    __device__ void Run(SrcData src, DstData* p_dst)
+    {
+        constexpr auto vector_access_dim   = Number<DstVectorWriteDim>{};
+        constexpr auto dst_data_per_access = Number<DstDataPerWrite>{};
+
+        static_assert(DstDataPerWrite == 1 || DstDataPerWrite == 2 || DstDataPerWrite == 4, "");
+
+        constexpr auto long_vector_size = dst_data_per_access;
+
+        constexpr auto long_vector_access_lengths = SliceLengths::Modify(
+            vector_access_dim, SliceLengths::Get(vector_access_dim) / long_vector_size);
+
+        static_ford<decltype(long_vector_access_lengths), DstDimAccessOrder>{}(
+            [&](auto long_vector_access_id) {
+                constexpr auto long_vector_data_begin_id = long_vector_access_id.Modify(
+                    Number<vector_access_dim>{},
+                    Number<long_vector_size * long_vector_access_id[vector_access_dim]>{});
+
+                constexpr auto buff_off =
+                    ThreadBufferDesc::CalculateOffset(to_multi_index(long_vector_data_begin_id)) /
+                    long_vector_size;
+
+                auto src_buff = src.At(Number<DstDataPerWrite>{})[Number<buff_off>{}];
+
+                const auto dst_coord = mDstSliceOrigin + to_multi_index(long_vector_data_begin_id);
+
+                vector_data_store<DstData, DstDataPerWrite>::run(p_dst, src_buff, dst_coord);
+            });
+    }
+
     template <typename T, bool PositiveDirection>
     __device__ void MoveSrcSliceWindow(const T& step_sizes_,
                                        integral_constant<bool, PositiveDirection>)
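This new Run overload is what produces the vectorized output: it splits the slice along DstVectorWriteDim into chunks of DstDataPerWrite elements, reads each chunk from the statically indexed source buffer via src.At(Number<DstDataPerWrite>{})[Number<buff_off>{}], and hands it to vector_data_store as a single write. The sketch below shows the same store pattern in plain host C++ under simplified assumptions; float_vec, vector_store, and the fixed slice length are hypothetical stand-ins, not the repo's descriptor and coordinate machinery.

#include <array>
#include <cstddef>
#include <cstdio>
#include <cstring>

template <std::size_t W>
struct float_vec { float data[W]; };   // stand-in for a W-wide vector type

// Model of one vector write: W floats stored to dst + offset in a single call.
template <std::size_t W>
void vector_store(float* dst, const float_vec<W>& v, std::size_t offset)
{
    std::memcpy(dst + offset, v.data, sizeof(v.data));
}

int main()
{
    constexpr std::size_t slice_len = 8;   // length along the vectorized dimension
    constexpr std::size_t W         = 4;   // DstDataPerWrite analogue (1, 2, or 4)

    // Source buffer already grouped into slice_len / W vectors
    // (the analogue of viewing the register buffer through At(Number<W>{})).
    std::array<float_vec<W>, slice_len / W> src{{{{0, 1, 2, 3}}, {{4, 5, 6, 7}}}};

    std::array<float, slice_len> dst{};

    // One store per W-wide chunk: buff_off selects the source vector and
    // W * buff_off is the destination offset (the real Run derives both
    // from ThreadBufferDesc and the destination coordinate).
    for(std::size_t buff_off = 0; buff_off < slice_len / W; ++buff_off)
        vector_store<W>(dst.data(), src[buff_off], W * buff_off);

    for(float x : dst)
        std::printf("%g ", double(x));
    std::printf("\n");
}

Per the static_assert, the output path can move up to four floats per store, depending on DstDataPerWrite.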
composable_kernel/include/utility/float_type.amd.hpp.in
View file @ 1d6022b1

@@ -188,7 +188,6 @@ union float_vec64_t
     StaticallyIndexedArray<float, 64> s1;
     StaticallyIndexedArray<float32_t, 2> s32;
     StaticallyIndexedArray<float64_t, 1> s64;
-    float n[64];

     __host__ __device__ constexpr float_vec64_t() {}

     template <index_t vs>

@@ -210,10 +209,10 @@ union float_vec64_t
 union float_vec128_t
 {
     StaticallyIndexedArray<float, 64> s1;
+    StaticallyIndexedArray<float_vec16_t, 8> s16;
     StaticallyIndexedArray<float32_t, 4> s32;
     StaticallyIndexedArray<float_vec64_t, 2> s64;
     StaticallyIndexedArray<float128_t, 1> s128;
-    float n[128];

     __host__ __device__ constexpr float_vec128_t() {}

     template <index_t vs>

@@ -225,6 +224,12 @@ union float_vec128_t
         return s1;
     }

+    template <>
+    __host__ __device__ auto& At(Number<16>)
+    {
+        return s16;
+    }
+
     template <>
     __host__ __device__ auto& At(Number<32>)
     {

@@ -238,8 +243,6 @@ union float_vec128_t
     }
 };

 template <typename T, index_t BufferSize>
 constexpr auto GetRegBuffer();
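The float_vec128_t change is what lets the GEMM epilogue address its accumulators block by block: the union gains an s16 member (eight 16-wide views of the same storage) plus an At(Number<16>) accessor, which is exactly what c_thread_vec.At(Number<16>{})[Number<blk_id>{}] in the first file relies on. A simplified standalone sketch of that accessor pattern follows; the types and plain overloads (instead of the explicitly specialized member template used in float_type.amd.hpp.in) are illustrative only.

#include <cstdio>

template <int N>
struct Number {};   // compile-time tag, analogous to CK's Number<>

struct vec4  { float v[4];  };
struct vec16 { float v[16]; };

// One buffer, three statically selected views of the same 16 floats.
union float_vec16_sketch
{
    float s1[16];   // scalar view
    vec4  s4[4];    // four 4-wide views
    vec16 s16;      // one 16-wide view

    constexpr float_vec16_sketch() : s1{} {}

    // At(Number<k>) returns the view whose elements are k floats wide.
    float* At(Number<1>)  { return s1; }
    vec4*  At(Number<4>)  { return s4; }
    vec16& At(Number<16>) { return s16; }
};

int main()
{
    float_vec16_sketch buf;

    // Fill through the scalar view...
    for(int i = 0; i < 16; ++i)
        buf.At(Number<1>{})[i] = float(i);

    // ...and read the second 4-wide chunk through the vector view.
    // (Type-punning through a union is not strictly portable C++, but it is
    //  the idiom the float_vec unions in this file rely on.)
    vec4 chunk = buf.At(Number<4>{})[1];
    std::printf("%g %g %g %g\n",
                double(chunk.v[0]), double(chunk.v[1]),
                double(chunk.v[2]), double(chunk.v[3]));
}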