gaoqiong / composable_kernel · Commits

Commit 27dc055b
authored Feb 16, 2023 by aska-0096
parent 4ddda63b

    fix a host tensor bug and clean up flash-attn code
Showing 8 changed files with 83 additions and 54 deletions
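
At a glance, the commit does five things: (1) fixes a 6-D indexing bug in the host Tensor utility where the last index was dropped on the write side, (2) re-enables the warm-up launch and the 100-iteration timing loop in launch_and_time_kernel, (3) adds printf diagnostics to the early-exit paths of the WMMA device- and gridwise-level validity checks, (4) threads a new IntraRowSwizzlePerm flag through the inter-row threadwise transfer used by the attention kernel, and (5) deletes leftover bring-up printfs and dead comments.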
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp                     +38 −22
include/ck/host_utility/kernel_launch.hpp                                                            +3  −3
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp  +8  −0
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp            +4  −1
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp                      +7  −8
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp                          +22 −17
include/ck/utility/data_type.hpp                                                                     +0  −2
library/include/ck/library/utility/host_tensor.hpp                                                   +1  −1
example/29_batched_gemm_bias_e_permute/batched_gemm_bias_e_permute_wmma_fp16.cpp

@@ -43,9 +43,10 @@ using AElementOp = ck::tensor_operation::element_wise::PassThrough;
 using BElementOp   = ck::tensor_operation::element_wise::PassThrough;
 using CDEElementOp = ck::tensor_operation::element_wise::Add;
 
-static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::Default;
-static constexpr auto ABSpec   = ck::tensor_operation::device::TensorSpecialization::Packed;
+static constexpr auto GemmSpec = ck::tensor_operation::device::GemmSpecialization::MNKPadding;
+static constexpr auto ASpec    = ck::tensor_operation::device::TensorSpecialization::Default;
+static constexpr auto BSpec    = ck::tensor_operation::device::TensorSpecialization::Default;
 static constexpr auto DESpec   = ck::tensor_operation::device::TensorSpecialization::Default;
 
 using DeviceOpInstanceKKNN =
@@ -64,18 +65,18 @@ using DeviceOpInstanceKKNN =
        BElementOp,
        CDEElementOp,
        GemmSpec,
-       ABSpec,
-       ABSpec,
+       ASpec,
+       BSpec,
        DESpec,
        256,
        128,
-       256,
-       8,
+       128,
+       4,
        8,
        16,
        16,
        4,
-       4,
+       2,
        S<4, 64, 1>,
        S<1, 0, 2>,
        S<1, 0, 2>,
@@ -252,21 +253,6 @@ int main(int argc, char* argv[])
     ck::index_t K0 = 2048;
 
-    // A[G0, G1, M0, M1, K0]
-    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
-    std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
-    // B[G0, G1, N0, N1, K0]
-    std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
-    std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
-    // D[G0, G1, M0, N0, M1, N1]
-    std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
-    std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
-    // E[G0, G1, M0, N0, M1, N1]
-    std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
-    std::vector<ck::index_t> e_gs_ms_ns_strides{
-        G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
-
     if(argc == 1)
     {
         // use default case
@@ -277,13 +263,43 @@ int main(int argc, char* argv[])
         init_method = std::stoi(argv[2]);
         time_kernel = std::stoi(argv[3]);
     }
+    else if(argc == 11)
+    {
+        do_verification = std::stoi(argv[1]);
+        init_method     = std::stoi(argv[2]);
+        time_kernel     = std::stoi(argv[3]);
+        G0              = std::stoi(argv[4]);
+        G1              = std::stoi(argv[5]);
+        M0              = std::stoi(argv[6]);
+        M1              = std::stoi(argv[7]);
+        N0              = std::stoi(argv[8]);
+        N1              = std::stoi(argv[9]);
+        K0              = std::stoi(argv[10]);
+    }
     else
     {
         printf("arg1: verification (0=no, 1=yes)\n");
         printf("arg2: initialization (0=no init, 1=integer value, 2=decimal value)\n");
         printf("arg3: time kernel (0=no, 1=yes)\n");
+        printf("arg4-10: G0, G1, M0, M1, N0, N1, K0\n");
         exit(0);
     }
 
+    // A[G0, G1, M0, M1, K0]
+    std::vector<ck::index_t> a_gs_ms_ks_lengths{G0, G1, M0, M1, K0};
+    std::vector<ck::index_t> a_gs_ms_ks_strides{G1 * M0 * M1 * K0, M0 * M1 * K0, M1 * K0, K0, 1};
+    // B[G0, G1, N0, N1, K0]
+    std::vector<ck::index_t> b_gs_ns_ks_lengths{G0, G1, N0, N1, K0};
+    std::vector<ck::index_t> b_gs_ns_ks_strides{G1 * N0 * N1 * K0, N0 * N1 * K0, N1 * K0, K0, 1};
+    // D[G0, G1, M0, N0, M1, N1]
+    std::vector<ck::index_t> d_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
+    std::vector<ck::index_t> d_gs_ms_ns_strides{G1 * N0 * N1, N0 * N1, 0, 0, N1, 1};
+    // E[G0, G1, M0, N0, M1, N1]
+    std::vector<ck::index_t> e_gs_ms_ns_lengths{G0, G1, M0, M1, N0, N1};
+    std::vector<ck::index_t> e_gs_ms_ns_strides{
+        G1 * M0 * N0 * M1 * N1, M0 * N0 * M1 * N1, N0 * M1 * N1, N1, M1 * N1, 1};
+
     Tensor<ADataType> a_gs_ms_ks(a_gs_ms_ks_lengths, a_gs_ms_ks_strides);
     Tensor<BDataType> b_gs_ns_ks(b_gs_ns_ks_lengths, b_gs_ns_ks_strides);
     Tensor<DDataType> d_gs_ms_ns(d_gs_ms_ns_lengths, d_gs_ms_ns_strides);
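Net effect of this file's changes: the device-op instance switches to GemmSpecialization::MNKPadding with separate A/B tensor specializations and a smaller tile configuration, a new argc == 11 branch reads G0, G1, M0, M1, N0, N1, K0 from the command line, and the length/stride vectors are now built only after argument parsing so the newly parseable sizes actually take effect. With padding enabled, an invocation such as ./batched_gemm_bias_e_permute_wmma_fp16 1 1 1 2 8 32 8 32 8 64 (sizes illustrative, not from the commit) no longer has to land on tile-size multiples.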
include/ck/host_utility/kernel_launch.hpp

@@ -29,12 +29,12 @@ float launch_and_time_kernel(const StreamConfig& stream_config,
                block_dim.y,
                block_dim.z);
 
-        const int nrepeat = 1;
-        // printf("Warm up 1 time\n");
+        const int nrepeat = 100;
+        printf("Warm up 1 time\n");
         // warm up
-        // kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
+        kernel<<<grid_dim, block_dim, lds_byte, stream_config.stream_id_>>>(args...);
 
         printf("Start running %d times...\n", nrepeat);
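This restores the usual warm-up-then-measure benchmarking discipline that had been disabled during flash-attn debugging. For context, a minimal self-contained sketch of that pattern, assuming plain HIP events rather than CK's StreamConfig machinery and a hypothetical saxpy kernel:

    #include <hip/hip_runtime.h>
    #include <cstdio>

    __global__ void saxpy(int n, float a, const float* x, float* y)
    {
        int i = blockIdx.x * blockDim.x + threadIdx.x;
        if(i < n)
            y[i] = a * x[i] + y[i];
    }

    float time_kernel_ms(int n, float a, const float* x, float* y, hipStream_t stream)
    {
        const int nrepeat = 100;
        dim3 grid((n + 255) / 256), block(256);

        // one untimed warm-up launch, as re-enabled in the diff above
        saxpy<<<grid, block, 0, stream>>>(n, a, x, y);

        hipEvent_t start, stop;
        hipEventCreate(&start);
        hipEventCreate(&stop);
        hipEventRecord(start, stream);
        for(int i = 0; i < nrepeat; ++i)
            saxpy<<<grid, block, 0, stream>>>(n, a, x, y);
        hipEventRecord(stop, stream);
        hipEventSynchronize(stop);

        float total_ms = 0.f;
        hipEventElapsedTime(&total_ms, start, stop);
        hipEventDestroy(start);
        hipEventDestroy(stop);
        return total_ms / nrepeat; // average per-launch time
    }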
include/ck/tensor_operation/gpu/device/impl/device_batched_contraction_multiple_d_wmma_cshuffle.hpp

@@ -771,6 +771,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
         {
             if constexpr(!(is_same_v<AccDataType, float> || is_same_v<AccDataType, int32_t>))
             {
+                printf("DeviceOp: Arch check failure\n");
                 return false;
             }
         }
@@ -785,6 +786,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
               arg.e_grid_desc_m_n_,
               arg.block_2_ctile_map_))
         {
+            printf("GridwiseOp: Validity check failure\n");
             return false;
         }
@@ -799,6 +801,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
             if(!(arg.a_mz_stride_ == 1 &&
                  arg.a_grid_desc_k0_m_k1_.GetLength(I1) % ABlockTransferSrcScalarPerVector == 0))
             {
+                printf("DeviceOp: Vector Access A-m check failure\n");
                 return false;
             }
         }
@@ -807,6 +810,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
             if(!(arg.a_kz_stride_ == 1 &&
                  arg.a_grid_desc_k0_m_k1_.GetLength(I2) % ABlockTransferSrcScalarPerVector == 0))
             {
+                printf("DeviceOp: Vector Access A-k check failure\n");
                 return false;
             }
         }
@@ -817,6 +821,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
             if(!(arg.b_nz_stride_ == 1 &&
                  arg.b_grid_desc_k0_n_k1_.GetLength(I1) % BBlockTransferSrcScalarPerVector == 0))
             {
+                printf("DeviceOp: Vector Access B-n check failure\n");
                 return false;
             }
         }
@@ -825,6 +830,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
             if(!(arg.b_kz_stride_ == 1 &&
                  arg.b_grid_desc_k0_n_k1_.GetLength(I2) % BBlockTransferSrcScalarPerVector == 0))
             {
+                printf("DeviceOp: Vector Access B-k check failure\n");
                 return false;
             }
         }
@@ -838,6 +844,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
                     CDEShuffleBlockTransferScalarPerVector_NPerBlock ==
                 0))
             {
+                printf("DeviceOp: Vector Access D-n check failure\n");
                 valid_d_access = false;
             }
         });
@@ -854,6 +861,7 @@ struct DeviceBatchedContractionMultipleD_Wmma_CShuffle
               0) ||
             CDEShuffleBlockTransferScalarPerVector_NPerBlock == 1))
         {
+            printf("DeviceOp: Vector Access E-n check failure\n");
            return false;
         }
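These eight additions change no decision logic; each early return false (or valid_d_access = false) in the argument-support checks now reports which specific check rejected the argument, so an unsupported problem no longer fails silently.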
include/ck/tensor_operation/gpu/grid/gridwise_batched_gemm_softmax_gemm_wmma_cshuffle.hpp

@@ -352,6 +352,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
         const auto K = a_grid_desc_ak0_m_ak1.GetLength(I0) * a_grid_desc_ak0_m_ak1.GetLength(I2);
         const auto N = b1_grid_desc_l0_n_l1.GetLength(I1);
+        printf("M = %d, L = %d, K = %d, N = %d\n", M, L, K, N);
+
         const auto KPerBlock = K0PerBlock * K1Value;
 
         if(!(M == c_grid_desc_m_n.GetLength(I0) && N == c_grid_desc_m_n.GetLength(I1)))
         {
@@ -730,7 +732,8 @@ struct GridwiseBatchedGemmSoftmaxGemm_Wmma_CShuffle
                 // dst Rowlane
                 // 0x76543210 0xfedcba98
                 // src Rowlane
-                0x76543210, 0xfedcba98>{tensor_operation::element_wise::PassThrough{}};
+                0x76543210, 0xfedcba98, false>{
+                tensor_operation::element_wise::PassThrough{}};
 
         // B1 matrix blockwise copy
         auto b1_blockwise_copy =
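Two things happen here: the validity check gains a printf of the runtime problem sizes M/L/K/N, and the inter-row transfer instantiation passes a trailing false for the new IntraRowSwizzlePerm template parameter introduced in threadwise_tensor_slice_transfer.hpp below, preserving this attention kernel's previous no-swizzle behavior.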
include/ck/tensor_operation/gpu/grid/gridwise_gemm_multiple_d_wmma_cshuffle.hpp

@@ -148,14 +148,12 @@ __global__ void
                 const Block2CTileMap block_2_etile_map)
 {
 #if(!defined(__HIP_DEVICE_COMPILE__) || defined(__gfx1100__))
-    //printf("entry kernel launch");
     __shared__ char p_shared[GridwiseOp::GetSharedMemoryNumberOfByte()];
 
     const index_t num_blocks_per_batch =
         __builtin_amdgcn_readfirstlane(get_grid_size() / batch_count);
     const index_t g_idx = __builtin_amdgcn_readfirstlane(get_block_1d_id() / num_blocks_per_batch);
-    //printf("before compute_ptr_offset call");
     const long_index_t a_batch_offset = __builtin_amdgcn_readfirstlane(
         static_cast<long_index_t>(compute_ptr_offset_of_batch.GetAPtrOffset(g_idx)));
     const long_index_t b_batch_offset = __builtin_amdgcn_readfirstlane(
@@ -170,13 +168,9 @@ __global__ void
     DsPointer p_ds_grid_grp;
-    //printf("before allocate pointer d");
     static_for<0, NumDTensor, 1>{}(
         [&](auto i) { p_ds_grid_grp(i) = p_ds_grid[i] + ds_batch_offset[i]; });
-    //printf("before entry");
     GridwiseOp::template Run<HasMainKBlockLoop>(p_a_grid + a_batch_offset,
                                                 p_b_grid + b_batch_offset,
                                                 p_ds_grid_grp,
@@ -469,16 +463,23 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
         if(!valid)
         {
+            printf("GridwiseOp: D descriptor dimension check failure\n");
             return false;
         }
 
         if(!(M == e_grid_desc_m_n.GetLength(I0) && N == e_grid_desc_m_n.GetLength(I1) &&
              K0 == b_grid_desc_k0_n_k1.GetLength(I0) && K1 == a_grid_desc_k0_m_k1.GetLength(I2) &&
             K1 == b_grid_desc_k0_n_k1.GetLength(I2)))
+        {
+            printf("GridwiseOp: ABE descriptor dimension cross check failure\n");
             return false;
+        }
 
         if(!(M % MPerBlock == 0 && N % NPerBlock == 0 && K0 % K0PerBlock == 0))
+        {
+            printf("GridwiseOp: Problemsize descriptor dimension check failure\n");
             return false;
+        }
 
         // check gridwise gemm pipeline
         const auto num_k_loop = K0 / K0PerBlock;
@@ -570,7 +571,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
               const CDEElementwiseOperation& cde_element_op,
               const Block2CTileMap& block_2_ctile_map)
     {
-        //printf("safe entry");
        // clang-format off
/*******************************************************************************/
        // Memory buffer zone.
@@ -716,7 +716,6 @@ struct GridwiseGemmMultipleD_k0mk1_k0nk1_mn_wmma_cshuffle
                                                      c_thread_buf,
                                                      K0BlockMainLoop);
/*******************************************************************************/
-        //printf("safe 1");
        // write out to C, implement shuffle
        {
            constexpr auto c_thread_desc_mrepeat_mwave_msubgroup_nrepeat_nwave_nthreadpersubgroup_maccvgprs =
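Same cleanup pattern as the device file: the commented-out tracing printfs from kernel bring-up are deleted, and the three validity checks in this gridwise op gain failure messages mirroring the device-level diagnostics above.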
include/ck/tensor_operation/gpu/thread/threadwise_tensor_slice_transfer.hpp

@@ -1315,6 +1315,7 @@ template <typename SrcData,
           index_t DstScalarPerVector,
           uint32_t LowEightRowlaneIdx,
           uint32_t HighEightRowLaneIdx,
+          bool IntraRowSwizzlePerm,
           typename enable_if<SrcDesc::IsKnownAtCompileTime() && DstDesc::IsKnownAtCompileTime(),
                              bool>::type = false>
 struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
@@ -1389,29 +1390,33 @@ struct ThreadwiseTensorSliceTransfer_StaticToStatic_InterRow
                 constexpr index_t dst_offset = dst_desc.CalculateOffset(
                     dst_slice_origin_idx + idx_md + i * dst_scalar_step_in_vector);
 
-                SrcData v;
+                SrcData v_this_row, v_theother_row;
+                // int type temp value due to intrinsic requirement
+                int temp = 0;
 
                 // apply element-wise operation
-                element_op_(v, src_buf[Number<src_offset>{}]);
+                element_op_(v_this_row, src_buf[Number<src_offset>{}]);
 
+                // apply intra-row swizzle permute
+                if constexpr(IntraRowSwizzlePerm){
+                    // origin: 0xfedcba98, 0x76543210
+                    temp = __builtin_amdgcn_permlane16(
+                        temp, type_convert<int>(v_this_row), 0xeca86420, 0xfdb97531, 1, 0);
+                    v_this_row = type_convert<float>(temp);
+                }
+
+                // apply inter-row permute.
+                temp = __builtin_amdgcn_permlanex16(
+                    temp, type_convert<int>(v_this_row), LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
+                v_theother_row = type_convert<float>(temp);
+
                 if(get_thread_local_1d_id() % 32 < 16){
                     // apply type convert
-                    dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v);
+                    dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v_this_row);
+                    dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
+                        type_convert<DstData>(v_theother_row);
                 }
                 else
                 {
                     // apply type convert
-                    dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(v);
+                    dst_buf(Number<dst_offset + DstScalarPerVector>{}) =
+                        type_convert<DstData>(v_this_row);
+                    dst_buf(Number<dst_offset>{}) = type_convert<DstData>(v_theother_row);
                 }
-
-                SrcData d = 0;
-                int temp  = 0;
-                temp      = __builtin_amdgcn_permlanex16(
-                    temp, type_convert<int>(v), LowEightRowlaneIdx, HighEightRowLaneIdx, 1, 0);
-                d = type_convert<float>(temp);
-                if(get_thread_local_1d_id() % 32 < 16){
-                    dst_buf(Number<dst_offset + DstScalarPerVector>{}) = type_convert<DstData>(d);
-                }
-                else
-                {
-                    dst_buf(Number<dst_offset>{}) = type_convert<DstData>(d);
-                }
             });
         });
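The rewrite computes both halves of the inter-row exchange up front: v_this_row is the lane's own (optionally intra-row swizzled) value, v_theother_row is the value fetched from the paired 16-lane row via __builtin_amdgcn_permlanex16, and each lane then stores both into dst_buf in one pass instead of two separate if/else blocks. To see what the new 0xeca86420 / 0xfdb97531 selector pair does, here is a host-side model of the selector encoding. The encoding is an assumption for illustration only: each 4-bit field of the combined 64-bit sel1:sel0 value names the source lane for the corresponding destination lane, low nibble first.

    #include <cstdint>
    #include <cstdio>

    // Hypothetical host model of the permlane16 lane-select encoding; not CK code.
    void permlane16_model(int dst[16], const int src[16], uint32_t sel0, uint32_t sel1)
    {
        const uint64_t sel = (uint64_t(sel1) << 32) | sel0;
        for(int lane = 0; lane < 16; ++lane)
            dst[lane] = src[(sel >> (4 * lane)) & 0xf];
    }

    int main()
    {
        int src[16], dst[16];
        for(int i = 0; i < 16; ++i)
            src[i] = i;
        permlane16_model(dst, src, 0xeca86420u, 0xfdb97531u);
        for(int i = 0; i < 16; ++i)
            printf("%d ", dst[i]); // 0 2 4 6 8 10 12 14 1 3 5 7 9 11 13 15
        printf("\n");
    }

Under this model the constants de-interleave each 16-lane row: even source lanes land in lanes 0-7 and odd source lanes in lanes 8-15, which squares with the identity pair 0x76543210 / 0xfedcba98 noted in the surrounding comments. Treat the encoding model, not the hardware documentation, as the assumption here.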
include/ck/utility/data_type.hpp

@@ -972,7 +972,6 @@ inline __host__ __device__ constexpr int type_convert<int, float>(float x)
         float fp32;
         int int32;
     } u = {x};
-    // u.fp32 = x;
     return u.int32;
 }
@@ -985,7 +984,6 @@ inline __host__ __device__ constexpr float type_convert<float, int>(int x)
         int int32;
         float fp32;
     } u = {x};
-    // u.fp32 = x;
     return u.fp32;
 }
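Pure cleanup: the dead // u.fp32 = x; comments are removed from the two union-based bit-cast specializations of type_convert; behavior is unchanged.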
library/include/ck/library/utility/host_tensor.hpp

@@ -396,7 +396,7 @@ struct Tensor
         }
         case 6: {
             auto f = [&](auto i0, auto i1, auto i2, auto i3, auto i4, auto i5) {
-                (*this)(i0, i1, i2, i3, i4) = g(i0, i1, i2, i3, i4, i5);
+                (*this)(i0, i1, i2, i3, i4, i5) = g(i0, i1, i2, i3, i4, i5);
             };
             make_ParallelTensorFunctor(f,
                                        mDesc.GetLengths()[0],
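This is the "host tensor bug" from the commit title: in the 6-D case the write side dropped i5, so every slice along the last dimension was generated into the same 5-D location, and all but the final i5 iteration were overwritten. A minimal standalone repro of the bug class (hypothetical code, not from CK):

    #include <cstdio>

    int main()
    {
        int buf[2][3] = {};
        auto g = [](int i0, int i1) { return 10 * i0 + i1; };

        for(int i0 = 0; i0 < 2; ++i0)
            for(int i1 = 0; i1 < 3; ++i1)
                buf[i0][0] = g(i0, i1); // bug: i1 dropped on the left-hand side

        // every i1 wrote slot 0; buf[0][0] ends up 2 (the last i1), buf[0][1..2] stay 0
        printf("%d %d %d\n", buf[0][0], buf[0][1], buf[0][2]);
    }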