#include "xpu_kernels.h" #include #include #include #include inline float dDequantizeFP4(unsigned char val) { if ((val & 0b1000) == 8) if ((val & 0b0100) == 4) if ((val & 0b0010) == 2) if ((val & 0b0001) == 1) return -0.25000000f; else return -0.16666667f; else if ((val & 0b0001) == 1) return -0.50000000f; else return -0.33333333f; else if ((val & 0b0010) == 2) if ((val & 0b0001) == 1) return -1.00000000f; else return -0.66666667f; else if ((val & 0b0001) == 1) return -5.208333333e-03f; else return 0.00000000f; else if ((val & 0b0100) == 4) if ((val & 0b0010) == 2) if ((val & 0b0001) == 1) return 0.25000000f; else return 0.16666667f; else if ((val & 0b0001) == 1) return 0.50000000f; else return 0.33333333f; else if ((val & 0b0010) == 2) if ((val & 0b0001) == 1) return 1.00000000f; else return 0.66666667f; else if ((val & 0b0001) == 1) return 5.208333333e-03f; else return 0.00000000f; } inline float dDequantizeNF4(unsigned char val) { // the values for this tree was generated by test_normal_map_tree // in the file tests/test_functional.py if ((val & 0b1000) == 8) if ((val & 0b0100) == 4) // 1 if ((val & 0b0010) == 2) // 11 if ((val & 0b0001) == 1) // 111 return 1.0f; //*1111 else return 0.7229568362236023f; //*1110 else if ((val & 0b0001) == 1) // 110 return 0.5626170039176941f; //*1101 else return 0.44070982933044434f; //*1100 else if ((val & 0b0010) == 2) // 10 if ((val & 0b0001) == 1) // 101 return 0.33791524171829224f; //*1011 else return 0.24611230194568634f; //*1010 else if ((val & 0b0001) == 1) // 100 return 0.16093020141124725f; //*1001 else return 0.07958029955625534f; //*1000 else if ((val & 0b0100) == 4) // 0 if ((val & 0b0010) == 2) // 01 if ((val & 0b0001) == 1) // 011 return 0.0f; //*0111 else return -0.09105003625154495f; //*0110 else if ((val & 0b0001) == 1) // 010 return -0.18477343022823334f; //*0101 else return -0.28444138169288635f; //*0100 else if ((val & 0b0010) == 2) // 00 if ((val & 0b0001) == 1) // 001 return -0.39491748809814453f; //*0011 else return -0.5250730514526367f; //*0010 else if ((val & 0b0001) == 1) // 000 return -0.6961928009986877f; //*0001 else return -1.0f; //*0000 } template SYCL_EXTERNAL void kDequantizeBlockwise::operator()(sycl::nd_item<1> item) const { const int base_idx = item.get_group(0) * TILE_SIZE; size_t local_idx = item.get_local_id(0) * NUM_PER_TH; float local_abs_max = -FLT_MAX; int local_load_idx = 0; int local_store_idx = 0; uint8_t qvals[NUM_PER_TH]; T vals[NUM_PER_TH * ((DATA_TYPE > 0) ? 2 : 1)]; if (DATA_TYPE > 0) { local_load_idx = sycl::min(TILE_SIZE, (n + 1) / 2 - base_idx); local_store_idx = sycl::min(TILE_SIZE * 2, n - base_idx * 2); } else { local_load_idx = sycl::min(TILE_SIZE, n - base_idx); local_store_idx = local_load_idx; } // Avoid expensive division by the blocksize (as blocksize will always be a // power-of-2) local_abs_max = absmax[(base_idx + local_idx) >> (31 - std::countl_zero(blocksize))]; if (local_idx + NUM_PER_TH < local_load_idx) { reinterpret_cast(&)[NUM_PER_TH]>(qvals)[0] = reinterpret_cast*>(A)[(base_idx + local_idx) / NUM_PER_TH]; } else { #pragma unroll NUM_PER_TH for (int i = 0; i < NUM_PER_TH; i++) { if (local_idx + i < local_load_idx) { qvals[i] = A[base_idx + local_idx + i]; } else { qvals[i] = (uint8_t)0; } } } switch (DATA_TYPE) { case General8bit: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) vals[j] = code[qvals[j]] * local_abs_max; break; case FP4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) { vals[j * 2] = dDequantizeFP4(qvals[j] >> 4) * local_abs_max; vals[j * 2 + 1] = dDequantizeFP4(qvals[j] & 0x0F) * local_abs_max; } break; case NF4: #pragma unroll NUM_PER_TH for (int j = 0; j < NUM_PER_TH; j++) { vals[j * 2] = dDequantizeNF4(qvals[j] >> 4) * local_abs_max; vals[j * 2 + 1] = dDequantizeNF4(qvals[j] & 0x0F) * local_abs_max; } break; } const int local_dst_size = (DATA_TYPE > 0) ? NUM_PER_TH * 2 : NUM_PER_TH; int local_dst_idx = (DATA_TYPE > 0) ? local_idx * 2 : local_idx; if (local_dst_idx + local_dst_size < local_store_idx) { reinterpret_cast*>( out )[(((DATA_TYPE > 0) ? base_idx * 2 : base_idx) + local_dst_idx) / local_dst_size] = reinterpret_cast(&)[local_dst_size]>(vals)[0]; } else { #pragma unroll NUM_PER_TH for (int i = 0; i < local_dst_size; i++) { if (local_dst_idx + i < local_store_idx) { out[((DATA_TYPE > 0) ? base_idx * 2 : base_idx) + local_dst_idx + i] = vals[i]; } } } } template SYCL_EXTERNAL void kgemv_4bit_inference::operator()(sycl::nd_item<1> item) const { size_t idx = item.get_local_id(); const int sg_idx = idx / SUBG_SIZE; const int sg_lane = idx % SUBG_SIZE; const int num_values_4bit = SUBG_SIZE; const int row_B = NUM_PER_THREAD * item.get_group().get_group_id() + sg_idx; const int offset_B = ldb * row_B; const int num_values_8bit = num_values_4bit / 2; float local_C = 0.0f; unsigned char local_B_4bit[num_values_8bit]; T local_B[num_values_4bit / 4]; T local_A[num_values_4bit / 4]; T local_absmax = T(0.0f); if (idx < 16) { quant_map[idx] = T(datatype[idx]); } item.barrier(sycl::access::fence_space::local_space); for (int inner_idx = sg_lane * num_values_4bit; inner_idx < K; inner_idx += SUBG_SIZE * num_values_4bit) { const int inner_idx_halved = inner_idx / 2; // Avoid expensive division by the blocksize (as blocksize will always be a // power-of-2) const int absidx = ((2 * offset_B) + inner_idx) >> (31 - std::countl_zero((unsigned int)blocksize)); local_absmax = absmax[absidx]; if (row_B < N) { if ((inner_idx_halved + num_values_8bit) < (K / 2)) { reinterpret_cast(&)[num_values_8bit]>(local_B_4bit)[0] = reinterpret_cast*>(B)[(offset_B + (inner_idx_halved)) / (num_values_8bit)]; } else { #pragma unroll for (int j = 0; j < (num_values_8bit); j++) if ((inner_idx_halved) + j < (K / 2)) local_B_4bit[j] = B[offset_B + inner_idx_halved + j]; else local_B_4bit[j] = 0b01110111; } } else { #pragma unroll for (int j = 0; j < (num_values_8bit); j++) local_B_4bit[j] = 0b01110111; } for (int i = 0; i < 4; i++) { #pragma unroll for (int k = 0; k < num_values_8bit / 4; k++) { local_B[k * 2] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] >> 4] * local_absmax; local_B[k * 2 + 1] = quant_map[local_B_4bit[(i * num_values_8bit / 4) + k] & 0x0F] * local_absmax; } if (inner_idx + (num_values_4bit / 4) + (i * num_values_4bit / 4) < K) { if (BITS == 16) { reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[0] = reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 4) + i]; } else { reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[0] = reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 0]; reinterpret_cast(&)[num_values_4bit / 4]>(local_A)[1] = reinterpret_cast*>(A)[inner_idx / (num_values_4bit / 8) + (2 * i) + 1]; } } else { #pragma unroll for (int k = 0; k < num_values_4bit / 4; k++) if (inner_idx + (i * num_values_4bit / 4) + k < K) local_A[k] = A[inner_idx + k + (i * num_values_4bit / 4)]; else local_A[k] = T(0.0f); } // accumulate in float for accuracy; #pragma unroll for (int k = 0; k < num_values_4bit / 4; k++) { local_C += (float)(local_A[k] * local_B[k]); } } } local_C = sycl::reduce_over_group(item.get_sub_group(), local_C, sycl::plus<>()); if (row_B < N && sg_lane == 0) out[row_B] = T(local_C); } //============================================================== // TEMPLATE DEFINITIONS //============================================================== template class kDequantizeBlockwise; template class kDequantizeBlockwise; template class kDequantizeBlockwise; template class kDequantizeBlockwise; template class kDequantizeBlockwise; template class kDequantizeBlockwise; template class kDequantizeBlockwise; template class kDequantizeBlockwise; template class kDequantizeBlockwise; template class kgemv_4bit_inference; template class kgemv_4bit_inference; template class kgemv_4bit_inference;