"...composable_kernel_rocm.git" did not exist on "bf0addb5753fb44e39e33e0edfa2158d6f1ffce7"
Commit 160cf6ed authored by Muhammed Ozturk
Browse files

update

parent 7582c18e
......@@ -35,7 +35,38 @@ using namespace std;
__constant__ int const_internal_t2_offset[MAX_CONST_LEN];
__constant__ int const_internal_v2_offset[MAX_CONST_LEN];
// created by tc_gen_code_Kernel()
// Complex value held as two single-precision floats.
struct Complex
{
    // Real part first, imaginary part second.
    float re, im;
};
// Device helper: returns the complex product a * b, i.e.
//   (a.re + i*a.im) * (b.re + i*b.im)
// Both operands are taken by value and left unmodified.
__device__ Complex ComplexMul(Complex a, Complex b)
{
    Complex product = {
        a.re * b.re - a.im * b.im,  // real part
        a.re * b.im + a.im * b.re   // imaginary part
    };
    return product;
}
// Device helper: returns the component-wise sum a + b.
// Both operands are taken by value and left unmodified.
__device__ Complex ComplexAdd(Complex a, Complex b)
{
    Complex sum = {
        a.re + b.re,  // real part
        a.im + b.im   // imaginary part
    };
    return sum;
}
__global__ void kernel__1_1(float* dev_t3,
float* dev_t2,
float* dev_v2,
......@@ -269,17 +300,17 @@ int size_internal)
}
// created by tc_gen_code_Kernel()
__global__ void kernel__3_1(float* dev_t3,
float* dev_t2,
float* dev_v2,
__global__ void kernel__3_1(Complex* dev_t3,
Complex* dev_t2,
Complex* dev_v2,
int size_a, int size_b, int size_c, int size_d, int size_e, int size_f,
int numBlk_a, int numBlk_b, int numBlk_c, int numBlk_d,
int stride_reg_x, int stride_reg_y,
int size_internal)
{
// For Shared Memory,
__shared__ float sm_a[16][96];
__shared__ float sm_b[16][96];
__shared__ Complex sm_a[16][96];
__shared__ Complex sm_b[16][96];
// when opt_pre_computed == -1, all indices will be calculated manually
......@@ -337,13 +368,18 @@ int size_internal)
rng_d = size_d % SIZE_SLICE_1_D;
}
float temp_av;
float temp_bv[6];
float reg_tile[6][6];
Complex temp_av;
Complex temp_bv[6];
Complex reg_tile[6][6];
for (int i = 0; i < 6; i++)
for (int j = 0; j < 6; j++)
reg_tile[i][j] = 0.0;
// Zero every entry of the 6x6 register tile before accumulation.
// Both components must be cleared: the tile is later updated with
// ComplexAdd(reg_tile[..][..], ...), which reads .re and .im.
for (int i = 0; i < 6; i++) {
    for (int j = 0; j < 6; j++) {
        reg_tile[i][j].re = 0.0f;
        reg_tile[i][j].im = 0.0f;  // was a truncated statement: no assignment, no ';'
    }
}
// tensor contraction: [[16, 'STR_SD2_T2_H7', 'x', 't2', ['a', 'e', 'b', 'f']], [16, 'STR_SD2_V2_H7', 'y', 'v2', ['d', 'f', 'c', 'e']], '+=']
#pragma unroll 1
......@@ -352,7 +388,7 @@ int size_internal)
//---------------------------------------------------------------------------------------------------
// This is for the new version
// This Part is for Loading Input-Left
// tc_gen_code_Kernel_Load_Inputs_Abstracts()
if (idx_a < rng_a)
for (int ll = 0; ll < rng_b; ll++)
{
......@@ -363,7 +399,7 @@ int size_internal)
}
// This Part is for Loading Input-Right
// tc_gen_code_Kernel_Load_Inputs_Abstracts()
if (idx_a < rng_d)
for (int ll = 0; ll < rng_c; ll++)
{
......@@ -390,12 +426,21 @@ int size_internal)
{
temp_av = sm_a[ll][idx_a + (xx * 16)];
reg_tile[0][xx] += temp_av * temp_bv[0];
reg_tile[1][xx] += temp_av * temp_bv[1];
reg_tile[2][xx] += temp_av * temp_bv[2];
reg_tile[3][xx] += temp_av * temp_bv[3];
reg_tile[4][xx] += temp_av * temp_bv[4];
reg_tile[5][xx] += temp_av * temp_bv[5];
// reg_tile[0][xx] += temp_av * temp_bv[0];
// reg_tile[1][xx] += temp_av * temp_bv[1];
// reg_tile[2][xx] += temp_av * temp_bv[2];
// reg_tile[3][xx] += temp_av * temp_bv[3];
// reg_tile[4][xx] += temp_av * temp_bv[4];
// reg_tile[5][xx] += temp_av * temp_bv[5];
reg_tile[0][xx] = ComplexAdd(reg_tile[0][xx] , ComplexMul(temp_av, temp_bv[0] )) ;
reg_tile[1][xx] = ComplexAdd(reg_tile[1][xx] , ComplexMul(temp_av, temp_bv[1] )) ;
reg_tile[2][xx] = ComplexAdd(reg_tile[2][xx] , ComplexMul(temp_av, temp_bv[2] )) ;
reg_tile[3][xx] = ComplexAdd(reg_tile[3][xx] , ComplexMul(temp_av, temp_bv[3] )) ;
reg_tile[4][xx] = ComplexAdd(reg_tile[4][xx] , ComplexMul(temp_av, temp_bv[4] )) ;
reg_tile[5][xx] = ComplexAdd(reg_tile[5][xx] , ComplexMul(temp_av, temp_bv[5] )) ;
}
}
__syncthreads();
......@@ -689,7 +734,7 @@ int size_internal)
}
}
// created by tc_gen_code_Kernel()
__global__ void kernel__2_tex_1(float* dev_t3,
float* dev_t2,
float* dev_v2,
......@@ -747,7 +792,7 @@ int size_internal)
//---------------------------------------------------------------------------------------------------
// This is for the new version
// This Part is for Loading Input-Left
// tc_gen_code_Kernel_Load_Inputs_Abstracts()
if (threadIdx.y < SIZE_INT_UNIT_1 - internal_upperbound)
for (int ll = 0; ll < 6; ll++)
{
......@@ -1119,7 +1164,7 @@ int size_internal)
}
}
// written by tc_interface.tc_gen_code_interface_Header()
extern "C"
void sd_t_d2_fusion(int size_a, int size_b, int size_c, int size_d, int size_e, int size_f, float* t3, float* host_t2, float* host_v2, int cond_kernel_1, int opt_register_transpose)
{
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment