gaoqiong / composable_kernel_ROCM · Commits · 65cfb2a1

Commit 65cfb2a1, authored Oct 21, 2024 by Jing Zhang
Parent: 398f8851

    format

Showing 5 changed files with 25 additions and 24 deletions (+25 -24)
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp           +11  -8
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp              +0  -1
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp         +4  -4
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp    +8  -9
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp        +2  -2
include/ck/tensor_operation/gpu/element/unary_element_wise_operation.hpp

@@ -55,8 +55,8 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
 #else
    uint8_t x_u8 = ck::bit_cast<uint8_t>(q);
    int x_l      = (x_u8 & 0x0f);
    int x_h      = (x_u8 & 0xf0) << 12;
    const int EX = 0x64006400;
@@ -66,7 +66,6 @@ __device__ inline half2_t pki4_to_half2(pk_i4_t q)
    return amd_assembly_pk_add_f16(bit_cast<half2_t>(lo), bit_cast<half2_t>(SUB));
 #endif
}

struct PassThroughPack8
@@ -87,12 +86,16 @@ struct PassThroughPack8
        vector_type<half_t, 8> dst;
        vector_type<pk_i4_t, 4> src{x};

        dst.template AsType<half2_t>()(Number<0>{}) =
            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<0>{}]);
        dst.template AsType<half2_t>()(Number<1>{}) =
            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<1>{}]);
        dst.template AsType<half2_t>()(Number<2>{}) =
            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<2>{}]);
        dst.template AsType<half2_t>()(Number<3>{}) =
            pki4_to_half2(src.template AsType<pk_i4_t>()[Number<3>{}]);

        y = dst.template AsType<half8_t>()[Number<0>{}];
 #endif
    }
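For context, pki4_to_half2 decodes two packed 4-bit values into two fp16 lanes with a magic-number trick: each nibble is OR-ed into the mantissa of an fp16 whose exponent bits (0x6400) encode 1024.0, so the lane holds 1024 + nibble, and a packed subtract then removes the bias. Below is a minimal host-side sketch of that idea, not the kernel code; the 1032.0f bias is an assumption chosen so that each nibble n maps to n - 8 (the SUB constant itself is not shown in this hunk), which matches the reference decode in reference_gemm.hpp further down.

```cpp
// Host-side sketch (assumed bias, illustrative input) of the magic-number int4 -> fp16 decode.
#include <cstdint>
#include <cstdio>
#include <cstring>

static float half_bits_to_float(uint16_t h)
{
    // Minimal fp16 -> fp32 conversion for normal numbers (enough for this value range).
    uint32_t sign = static_cast<uint32_t>(h & 0x8000u) << 16;
    uint32_t exp  = ((h >> 10) & 0x1Fu) - 15u + 127u;
    uint32_t man  = static_cast<uint32_t>(h & 0x3FFu) << 13;
    uint32_t bits = sign | (exp << 23) | man;
    float f;
    std::memcpy(&f, &bits, sizeof(f));
    return f;
}

int main()
{
    const uint8_t q = 0x2B; // packed pair: low nibble 0xB, high nibble 0x2

    // Place each nibble into the mantissa of an fp16 whose exponent encodes 1024.0 (0x6400),
    // so each 16-bit lane becomes 1024 + nibble.
    uint32_t x_l = q & 0x0Fu;
    uint32_t x_h = static_cast<uint32_t>(q & 0xF0u) << 12; // high nibble -> upper lane's low mantissa bits
    uint32_t lo  = 0x64006400u | x_l | x_h;

    // Subtracting 1032 turns 1024 + nibble into nibble - 8 (offset-binary int4). Assumed bias.
    float lane0 = half_bits_to_float(static_cast<uint16_t>(lo & 0xFFFFu)) - 1032.0f;
    float lane1 = half_bits_to_float(static_cast<uint16_t>(lo >> 16)) - 1032.0f;

    std::printf("lane0 = %.1f, lane1 = %.1f\n", lane0, lane1); // expect 3.0 and -6.0
}
```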
include/ck/tensor_operation/gpu/grid/gridwise_gemm_xdl_cshuffle_v3.hpp

@@ -1370,7 +1370,6 @@ struct GridwiseGemm_xdl_cshuffle_v3
            c_thread_buf,
            num_k_block_main_loop);

        // shuffle C and write out
        {
            static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0 &&
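The hunk above guards the C-shuffle epilogue with a compile-time divisibility check. A standalone sketch of that pattern follows; the parameter names are used as illustrative template constants here, not the grid-wise GEMM's real arguments.

```cpp
// Minimal sketch of a compile-time tiling check in the style of the static_assert above.
#include <cstdio>

template <int MXdlPerWave, int CShuffleMXdlPerWavePerShuffle>
constexpr int shuffle_passes()
{
    static_assert(MXdlPerWave % CShuffleMXdlPerWavePerShuffle == 0,
                  "shuffle tile must evenly divide the per-wave M tile");
    return MXdlPerWave / CShuffleMXdlPerWavePerShuffle;
}

int main()
{
    // 4 XDL tiles per wave in M, shuffled 2 at a time -> 2 shuffle passes.
    std::printf("%d\n", shuffle_passes<4, 2>());
}
```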
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1025,8 +1025,7 @@ struct ThreadwiseTensorSliceTransfer_v4
        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
        {
            static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
        }
    }
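The guard above rejects vector widths that are not a multiple of the packing factor of pk_i4_t (two 4-bit values per byte). Here is a self-contained sketch of the same compile-time pattern; ThreadTransferSketch and PackedTwoInt4 are made-up stand-ins for the transfer class and ck::pk_i4_t.

```cpp
// Illustrative only: mirrors the if constexpr + static_assert guard, not the CK class itself.
#include <cstdint>
#include <type_traits>

struct PackedTwoInt4 { uint8_t data; }; // stand-in for ck::pk_i4_t: two int4 values in one byte

template <typename SrcData, int SrcScalarPerVector>
struct ThreadTransferSketch
{
    // Packed types move in pairs, everything else moves one scalar at a time.
    static constexpr int PackedSize =
        std::is_same_v<std::remove_cv_t<SrcData>, PackedTwoInt4> ? 2 : 1;

    ThreadTransferSketch()
    {
        if constexpr(std::is_same_v<std::remove_cv_t<SrcData>, PackedTwoInt4>)
        {
            static_assert(SrcScalarPerVector % PackedSize == 0, "pk data N cannot be 1");
        }
    }
};

int main()
{
    ThreadTransferSketch<PackedTwoInt4, 8> ok{};   // 8 % 2 == 0: compiles
    // ThreadTransferSketch<PackedTwoInt4, 1> bad{}; // would trip the static_assert
    (void)ok;
}
```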
@@ -1126,8 +1125,9 @@ struct ThreadwiseTensorSliceTransfer_v4
            using src_vector_t = typename decltype(src_tmp_vector)::type;

-           //const bool is_src_valid = coordinate_has_valid_offset_assuming_visible_index_is_valid(
-           //src_desc, src_data_coord);
+           // const bool is_src_valid =
+           // coordinate_has_valid_offset_assuming_visible_index_is_valid( src_desc,
+           // src_data_coord);
            const bool is_src_valid = true;

            // copy data from src_buf into src_tmp_vector
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer_v3r1.hpp

@@ -80,14 +80,13 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        if constexpr(is_same_v<remove_cvref_t<SrcData>, pk_i4_t>)
        {
            static_assert(is_same_v<remove_cvref_t<SrcData>, remove_cvref_t<DstData>>,
                          "SrcData != DstData");
            static_assert(SrcScalarPerVector_ % PackedSize == 0 &&
                              DstScalarPerVector_ % PackedSize == 0,
                          "SrcScalarPerVector_ and DstScalarPerVector_ cannot be 1");
            static_assert(SrcVectorDim == DstVectorDim, "pk_i4_t does not support transpose");
        }
    }
@@ -446,7 +445,7 @@ struct ThreadwiseTensorSliceTransfer_v3r1
        else
        {
            constexpr auto packed_per_access = generate_sequence(
                detail::lambda_scalar_per_access<SrcVectorDim, PackedSize>{}, Number<nDim>{});

            constexpr auto packed_access_lengths = SliceLengths{} / packed_per_access;
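packed_access_lengths above is the slice lengths divided element-wise by the per-access packing along the vector dimension. A plain constexpr sketch of that arithmetic follows; the function name and the example lengths are illustrative, not taken from the kernel.

```cpp
// Host-side illustration of packed_access_lengths = SliceLengths{} / packed_per_access:
// a 4 x 64 slice read along dimension 1 with 2 scalars packed per access becomes 4 x 32 accesses.
#include <array>
#include <cstddef>
#include <cstdio>

template <std::size_t N>
constexpr std::array<int, N> packed_access_lengths(std::array<int, N> slice_lengths,
                                                   int vector_dim,
                                                   int packed_size)
{
    std::array<int, N> lengths = slice_lengths;
    lengths[vector_dim] /= packed_size; // only the vector dimension is packed
    return lengths;
}

int main()
{
    constexpr auto lengths =
        packed_access_lengths<2>({4, 64}, /*vector_dim=*/1, /*packed_size=*/2);
    std::printf("%d x %d\n", lengths[0], lengths[1]); // 4 x 32
}
```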
@@ -875,8 +874,8 @@ struct ThreadwiseTensorSliceTransfer_v3r1
    private:
    static constexpr auto src_thread_scratch_desc_ = decltype(GetSrcThreadScratchDescriptor()){};
-   //static constexpr auto src_oob_thread_scratch_desc_ =
-   //decltype(GetSrcThreadScratchDescriptor()){};
+   // static constexpr auto src_oob_thread_scratch_desc_ =
+   // decltype(GetSrcThreadScratchDescriptor()){};
    static constexpr auto dst_thread_scratch_desc_ = decltype(GetDstThreadScratchDescriptor()){};

    using SrcThreadScratch =
library/include/ck/library/reference_tensor_operation/cpu/reference_gemm.hpp

@@ -82,7 +82,7 @@ struct ReferenceGemm : public device::BaseOperator
                        i4 = (i4x2 >> 0) & 0xf;
                    else
                        i4 = (i4x2 >> 4) & 0xf;
                    i4 = i4 - 8;
                    v_a = type_convert<ComputeTypeA>(i4);
                }
                else
@@ -103,7 +103,7 @@ struct ReferenceGemm : public device::BaseOperator
                        i4 = (i4x2 >> 0) & 0xf;
                    else
                        i4 = (i4x2 >> 4) & 0xf;
                    i4 = i4 - 8;
                    v_b = type_convert<ComputeTypeB>(i4);
                }
                else
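The reference path above unpacks the int4 pair without any floating-point tricks: shift out the wanted nibble, mask, then subtract 8. A tiny host sketch of that decode, with an illustrative input byte:

```cpp
// Reference-style int4 decode: each byte i4x2 holds two offset-binary 4-bit values,
// low nibble first; subtracting 8 recovers the signed range [-8, 7].
#include <cstdint>
#include <cstdio>

int main()
{
    const uint8_t i4x2 = 0x2B; // low nibble 0xB (-> 3), high nibble 0x2 (-> -6)

    for(int k = 0; k < 2; ++k)
    {
        int i4 = (k == 0) ? ((i4x2 >> 0) & 0xf) : ((i4x2 >> 4) & 0xf);
        i4     = i4 - 8;
        std::printf("element %d = %d\n", k, i4);
    }
}
```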