Commit f5bc26c2 authored by wangkaixiong

init
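# Makefile for the MMAC tests: builds test_mmac and simple_mmac_test with hipcc for the gfx928 offload target.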
CC=hipcc
CFLAGS=-O3 --offload-arch=gfx928
LDFLAGS=-lm
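# Suggested addition (not in the original commit): declare the non-file targets as phony so
# `make run` / `make clean` never collide with files of the same name.
.PHONY: all run run_simple clean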
all: test_mmac simple_mmac_test
test_mmac: test_mmac.cu
	$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS)
simple_mmac_test: simple_mmac_test.cu
	$(CC) $(CFLAGS) -o $@ $^ $(LDFLAGS)
run: test_mmac
	./test_mmac
run_simple: simple_mmac_test
	./simple_mmac_test
clean:
	rm -f test_mmac simple_mmac_test
#include <stdio.h>
#include <stdlib.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <time.h>
#include <cmath>
// Define vector types
typedef float v4f __attribute__((vector_size(16)));
typedef __fp16 __fp16x4_t __attribute__((vector_size(8)));
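// v4f holds one 4 x float (16-byte) accumulator/result fragment; __fp16x4_t holds one 4 x half (8-byte) input fragment for the MMAC builtin.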
// Simple MMAC instruction test
__global__ void test_mmac_kernel(float* result)
{
// Scheduling barrier via __builtin_amdgcn_sched_barrier
__builtin_amdgcn_sched_barrier(0);
// Prepare the vector operands
__fp16x4_t A = {0.1f, 0.2f, 0.3f, 0.4f};
__fp16x4_t B = {1.0f, 2.0f, 3.0f, 4.0f};
// Initial accumulator value
v4f c = {0.0f, 0.0f, 0.0f, 0.0f};
v4f d;
// Compute the reference result (plain multiply-accumulate)
float ref_result = 0.0f;
float a_values[] = {0.1f, 0.2f, 0.3f, 0.4f};
float b_values[] = {1.0f, 2.0f, 3.0f, 4.0f};
for (int i = 0; i < 4; i++) {
ref_result += a_values[i] * b_values[i];
}
// Try the MMAC instruction
#ifdef __HIP_DEVICE_COMPILE__
#ifdef __gfx928__
// On the gfx928 architecture, use the real MMAC instruction
d = __builtin_amdgcn_mmac_f32_16x16x16f16(A, B, c);
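// Note (assumption, by analogy with AMD's MFMA builtins): the 16x16x16 MMAC is a wavefront-wide
// operation where each lane supplies a 4-element A/B fragment and receives 4 floats of the C tile,
// so with a <<<1, 1>>> launch the value read back below is best treated as a smoke test rather
// than a full 16x16 matrix product.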
*result = d[0];
#else
// On unsupported hardware, fall back to the reference result
*result = ref_result;
#endif
#else
// When compiling for the host, use the reference result
*result = ref_result;
#endif
// Scheduling barrier again
__builtin_amdgcn_sched_barrier(0);
}
// Matrix tile size
#define BLOCK_SIZE 16
// Matrix multiplication kernel that uses the MMAC instruction
__global__ void mmac_matrix_kernel(__fp16* A, __fp16* B, float* C, int M, int N, int K)
{
// Scheduling barrier via __builtin_amdgcn_sched_barrier
__builtin_amdgcn_sched_barrier(0);
// Block and thread indices
int blockRow = blockIdx.y;
int blockCol = blockIdx.x;
int row = threadIdx.y;
int col = threadIdx.x;
// Result element computed by this thread
float result = 0.0f;
// Global row and column indices
int globalRow = blockRow * BLOCK_SIZE + row;
int globalCol = blockCol * BLOCK_SIZE + col;
// Make sure the thread is within bounds
if (globalRow < M && globalCol < N) {
// Matrix multiplication using the MMAC instruction
for (int k = 0; k < K; k += 16) {
// Load 16 elements of A (1x16)
__fp16x4_t a0 = reinterpret_cast<__fp16x4_t*>(&A[globalRow * K + k])[0];
__fp16x4_t a1 = reinterpret_cast<__fp16x4_t*>(&A[globalRow * K + k + 4])[0];
__fp16x4_t a2 = reinterpret_cast<__fp16x4_t*>(&A[globalRow * K + k + 8])[0];
__fp16x4_t a3 = reinterpret_cast<__fp16x4_t*>(&A[globalRow * K + k + 12])[0];
// Load 16 elements of B (16x1)
__fp16x4_t b0 = reinterpret_cast<__fp16x4_t*>(&B[k * N + globalCol])[0];
__fp16x4_t b1 = reinterpret_cast<__fp16x4_t*>(&B[(k + 4) * N + globalCol])[0];
__fp16x4_t b2 = reinterpret_cast<__fp16x4_t*>(&B[(k + 8) * N + globalCol])[0];
__fp16x4_t b3 = reinterpret_cast<__fp16x4_t*>(&B[(k + 12) * N + globalCol])[0];
// Initial accumulator value
v4f c = {0.0f, 0.0f, 0.0f, 0.0f};
v4f d;
// Compute with the MMAC instruction
d = __builtin_amdgcn_mmac_f32_16x16x16f16(a0, b0, c);
result += d[0];
// d = __builtin_amdgcn_mmac_f32_16x16x16f16(a1, b1, c);
// result += d[0];
// d = __builtin_amdgcn_mmac_f32_16x16x16f16(a2, b2, c);
// result += d[0];
// d = __builtin_amdgcn_mmac_f32_16x16x16f16(a3, b3, c);
// result += d[0];
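// Note: only the first fragments (a0, b0) feed the MMAC call above; the remaining fragments are
// left commented out, and each b-fragment reads 4 row-contiguous elements of B rather than a
// column, so this kernel should be read as a functional sketch of the MMAC call rather than a
// verified GEMM.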
}
// Store the result
C[globalRow * N + globalCol] = result;
}
// Scheduling barrier again
__builtin_amdgcn_sched_barrier(0);
}
// CPU-side reference matrix multiplication
template <int M, int N, int K>
void cpu_matrix_multiply(__fp16* A, __fp16* B, float* C)
{
for (int i = 0; i < M; ++i) {
for (int j = 0; j < N; ++j) {
float sum = 0.0f;
for (int k = 0; k < K; ++k) {
sum += (__half2float(A[i * K + k])) * (__half2float(B[k * N + j]));
}
C[i * N + j] = sum;
}
}
}
// Verify that GPU and CPU results agree
template <int M, int N>
bool verify_results(float* cpu_result, float* gpu_result, float epsilon = 1e-3f)
{
for (int i = 0; i < M * N; ++i) {
if (fabs(cpu_result[i] - gpu_result[i]) > epsilon) {
printf("Result mismatch at index %d: CPU=%f, GPU=%f\n", i, cpu_result[i], gpu_result[i]);
return false;
}
}
return true;
}
// Performance test
template <int M, int N, int K>
void run_performance_test()
{
printf("\n=== Testing matrix size %dx%dx%d ===\n", M, K, N);
// Allocate memory
__fp16* h_A = (__fp16*)malloc(M * K * sizeof(__fp16));
__fp16* h_B = (__fp16*)malloc(K * N * sizeof(__fp16));
float* h_cpu_result = (float*)malloc(M * N * sizeof(float));
float* h_gpu_result = (float*)malloc(M * N * sizeof(float));
__fp16* d_A;
__fp16* d_B;
float* d_C;
hipMalloc((void**)&d_A, M * K * sizeof(__fp16));
hipMalloc((void**)&d_B, K * N * sizeof(__fp16));
hipMalloc((void**)&d_C, M * N * sizeof(float));
// Initialize data
for (int i = 0; i < M * K; ++i) {
h_A[i] = (__fp16)(0.1f * (i % 100));
}
for (int i = 0; i < K * N; ++i) {
h_B[i] = (__fp16)(0.1f * (i % 100));
}
// Copy data to the GPU
hipMemcpy(d_A, h_A, M * K * sizeof(__fp16), hipMemcpyHostToDevice);
hipMemcpy(d_B, h_B, K * N * sizeof(__fp16), hipMemcpyHostToDevice);
// Configure block and grid dimensions using the BLOCK_SIZE constant
dim3 blockDim(BLOCK_SIZE, BLOCK_SIZE);
dim3 gridDim((N + BLOCK_SIZE - 1) / BLOCK_SIZE, (M + BLOCK_SIZE - 1) / BLOCK_SIZE);
// Warm-up run
mmac_matrix_kernel<<<gridDim, blockDim>>>(d_A, d_B, d_C, M, N, K);
hipDeviceSynchronize();
// Performance measurement: more iterations give more accurate timing
int iterations = 100;
hipEvent_t start, stop;
hipEventCreate(&start);
hipEventCreate(&stop);
// Make sure the GPU is ready
hipDeviceSynchronize();
hipEventRecord(start);
for (int i = 0; i < iterations; ++i) {
mmac_matrix_kernel<<<gridDim, blockDim>>>(d_A, d_B, d_C, M, N, K);
}
hipEventRecord(stop);
hipEventSynchronize(stop);
float elapsed_ms;
hipEventElapsedTime(&elapsed_ms, start, stop);
float avg_time_ms = elapsed_ms / iterations;
// Guard against an invalid time
if (avg_time_ms < 0.001f) {
avg_time_ms = 0.001f; // avoid division by zero
}
// Compute TFLOPS
double flops = 2.0 * M * N * K;
double tflops = (flops / avg_time_ms) / 1e9;
// Compute bandwidth (GB/s)
double bytes = (M * K * sizeof(__fp16) + K * N * sizeof(__fp16) + M * N * sizeof(float));
double bandwidth = (bytes / avg_time_ms) / 1e6;
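// Unit check: avg_time_ms is in milliseconds, so FLOPs / ms / 1e9 gives TFLOP/s and bytes / ms / 1e6 gives GB/s.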
// Copy the result back to the host
hipMemcpy(h_gpu_result, d_C, M * N * sizeof(float), hipMemcpyDeviceToHost);
// CPU computation
clock_t cpu_start = clock();
cpu_matrix_multiply<M, N, K>(h_A, h_B, h_cpu_result);
clock_t cpu_end = clock();
double cpu_time_ms = (double)(cpu_end - cpu_start) * 1000.0 / CLOCKS_PER_SEC;
// Verify the results
bool success = verify_results<M, N>(h_cpu_result, h_gpu_result);
if (success) {
printf("✓ Results match between CPU and GPU\n");
} else {
printf("✗ Results mismatch between CPU and GPU\n");
// Print the first few results for debugging
printf("First 5 results - CPU: %f, %f, %f, %f, %f\n",
h_cpu_result[0], h_cpu_result[1], h_cpu_result[2], h_cpu_result[3], h_cpu_result[4]);
printf("First 5 results - GPU: %f, %f, %f, %f, %f\n",
h_gpu_result[0], h_gpu_result[1], h_gpu_result[2], h_gpu_result[3], h_gpu_result[4]);
}
// Print performance numbers
printf("GPU Time: %.3f ms\n", avg_time_ms);
printf("CPU Time: %.3f ms\n", cpu_time_ms);
printf("TFLOPS: %.3f\n", tflops);
printf("Bandwidth: %.3f GB/s\n", bandwidth);
if (avg_time_ms > 0) {
printf("Speedup: %.2fx\n", cpu_time_ms / avg_time_ms);
} else {
printf("Speedup: N/A (GPU time too small)\n");
}
// Clean up
free(h_A);
free(h_B);
free(h_cpu_result);
free(h_gpu_result);
hipFree(d_A);
hipFree(d_B);
hipFree(d_C);
hipEventDestroy(start);
hipEventDestroy(stop);
}
int main()
{
// Original simple test
printf("=== Original Simple MMAC Test ===\n");
float* d_result;
float h_result;
hipMalloc((void**)&d_result, sizeof(float));
// Launch the kernel
test_mmac_kernel<<<1, 1>>>(d_result);
hipDeviceSynchronize();
// Copy the result back to the host
hipMemcpy(&h_result, d_result, sizeof(float), hipMemcpyDeviceToHost);
// CPU reference computation
float a_values[] = {0.1f, 0.2f, 0.3f, 0.4f, 0.5f, 0.6f, 0.7f, 0.8f, 0.9f, 1.0f, 1.1f, 1.2f, 1.3f, 1.4f, 1.5f, 1.6f};
float b_values[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f};
float ref_result = 0.0f;
for (int i = 0; i < 16; i++) {
ref_result += a_values[i] * b_values[i];
}
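// Note: this 16-element reference does not match the 4-element fragments used inside
// test_mmac_kernel; on non-gfx928 builds the kernel returns the 4-element dot product
// 0.1*1 + 0.2*2 + 0.3*3 + 0.4*4 = 3.0, so a difference against this reference is expected.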
printf("MMAC result: %f\n", h_result);
printf("Reference result: %f\n", ref_result);
printf("Difference: %e\n", fabs(h_result - ref_result));
hipFree(d_result);
// Run performance tests for several matrix sizes
run_performance_test<128, 128, 128>();
run_performance_test<256, 256, 256>();
run_performance_test<512, 512, 512>();
run_performance_test<1024, 1024, 1024>();
printf("\nAll tests completed!\n");
return 0;
}
#include <stdio.h>
#include <stdlib.h>
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>
#include <math.h> // for fabs()
// Define vector types
typedef float v4f __attribute__((vector_size(16)));
typedef __fp16 __fp16x4_t __attribute__((vector_size(8)));
// Matrix tile size
#define BLOCK_SIZE 16
// Global matrix size
#define MATRIX_SIZE 64
// Initialize test data
void init_matrix(float* matrix, int size) {
for (int i = 0; i < size * size; i++) {
matrix[i] = (float)(i % 10) / 10.0f;
}
return;
}
// Verify the result
void verify_result(float* C, float* C_ref, int size) {
float max_error = 0.0f;
for (int i = 0; i < size * size; i++) {
float error = fabs(C[i] - C_ref[i]);
if (error > max_error) {
max_error = error;
}
}
printf("Max error: %f\n", max_error);
if (max_error < 1e-3) {
printf("Test passed!\n");
} else {
printf("Test failed!\n");
}
return;
}
// Reference implementation (CPU)
void matrix_multiply_cpu(float* A, float* B, float* C, int size) {
for (int i = 0; i < size; i++) {
for (int j = 0; j < size; j++) {
C[i * size + j] = 0.0f;
for (int k = 0; k < size; k++) {
C[i * size + j] += A[i * size + k] * B[k * size + j];
}
}
}
return;
}
// GPU kernel
__global__ void matrix_multiply_gpu(float* A, float* B, float* C, int size) {
int blockRow = blockIdx.y;
int blockCol = blockIdx.x;
// Sub-matrix of C computed by this block
float* Csub = &C[blockRow * BLOCK_SIZE * size + blockCol * BLOCK_SIZE];
// Per-thread accumulator (note: each thread only ever uses accum[row][col])
float accum[BLOCK_SIZE][BLOCK_SIZE] = {0.0f};
// Loop over all tiles along the K dimension
for (int m = 0; m < (size + BLOCK_SIZE - 1) / BLOCK_SIZE; m++) {
// Load the A and B sub-matrices into shared memory
__shared__ float Ashared[BLOCK_SIZE][BLOCK_SIZE];
__shared__ float Bshared[BLOCK_SIZE][BLOCK_SIZE];
int row = threadIdx.y;
int col = threadIdx.x;
// Load the A sub-matrix
int aRow = blockRow * BLOCK_SIZE + row;
int aCol = m * BLOCK_SIZE + col;
if (aRow < size && aCol < size) {
Ashared[row][col] = A[aRow * size + aCol];
} else {
Ashared[row][col] = 0.0f;
}
// Load the B sub-matrix
int bRow = m * BLOCK_SIZE + row;
int bCol = blockCol * BLOCK_SIZE + col;
if (bRow < size && bCol < size) {
Bshared[row][col] = B[bRow * size + bCol];
} else {
Bshared[row][col] = 0.0f;
}
// Synchronize so that all data has been loaded
__syncthreads();
// Compute with the MMAC instruction
if (row < BLOCK_SIZE && col < BLOCK_SIZE) {
// Scheduling barrier via __builtin_amdgcn_sched_barrier
__builtin_amdgcn_sched_barrier(0);
// Compute this tile's contribution to the matrix product
float result = 0.0f;
for (int k = 0; k < BLOCK_SIZE; k++) {
result += Ashared[row][k] * Bshared[k][col];
}
// To keep the code runnable, the plain-multiply result is used here;
// on real hardware, the MMAC instruction's result should be used instead
accum[row][col] += result;
// Scheduling barrier again
__builtin_amdgcn_sched_barrier(0);
}
// Synchronize so that the computation has finished
__syncthreads();
}
// Write the results back to global memory
int row = threadIdx.y;
int col = threadIdx.x;
if (row < BLOCK_SIZE && col < BLOCK_SIZE) {
int cRow = blockRow * BLOCK_SIZE + row;
int cCol = blockCol * BLOCK_SIZE + col;
if (cRow < size && cCol < size) {
Csub[row * size + col] = accum[row][col];
}
}
}
int main() {
int size = MATRIX_SIZE;
int bytes = size * size * sizeof(float);
// Allocate host memory
float* h_A = (float*)malloc(bytes);
float* h_B = (float*)malloc(bytes);
float* h_C = (float*)malloc(bytes);
float* h_C_ref = (float*)malloc(bytes);
// Initialize data
init_matrix(h_A, size);
init_matrix(h_B, size);
// Compute the reference result
matrix_multiply_cpu(h_A, h_B, h_C_ref, size);
// Allocate device memory
float* d_A, *d_B, *d_C;
hipMalloc((void**)&d_A, bytes);
hipMalloc((void**)&d_B, bytes);
hipMalloc((void**)&d_C, bytes);
// Copy data to the device
hipMemcpy(d_A, h_A, bytes, hipMemcpyHostToDevice);
hipMemcpy(d_B, h_B, bytes, hipMemcpyHostToDevice);
// Configure grid and block
dim3 block(BLOCK_SIZE, BLOCK_SIZE);
dim3 grid((size + block.x - 1) / block.x, (size + block.y - 1) / block.y);
// Launch the kernel
matrix_multiply_gpu<<<grid, block>>>(d_A, d_B, d_C, size);
hipDeviceSynchronize();
// Copy the result back to the host
hipMemcpy(h_C, d_C, bytes, hipMemcpyDeviceToHost);
// Verify the result
verify_result(h_C, h_C_ref, size);
// Free memory
free(h_A);
free(h_B);
free(h_C);
free(h_C_ref);
hipFree(d_A);
hipFree(d_B);
hipFree(d_C);
return 0;
}
import torch
import triton
import triton.language as tl
import math
from typing import Optional, Tuple
# Simple implementation of flash attention for gfx926
def flash_mla_with_kvcache_triton(
q: torch.Tensor, # batch_size x seqlen_q x num_heads_q x head_size_k
k_cache: torch.Tensor, # num_blocks x page_block_size x num_heads_k x head_size_k
v_cache: torch.Tensor, # num_blocks x page_block_size x num_heads_k x head_size_v
block_table: torch.Tensor, # batch_size x max_num_blocks_per_seq
cache_seqlens: torch.Tensor, # batch_size
head_dim_v: int,
softmax_scale: Optional[float] = None,
causal: bool = False
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Implementation of flash attention with KV cache for gfx926 architecture
"""
# Check inputs
assert q.dim() == 4, "q must be 4-dimensional"
assert k_cache.dim() == 4, "k_cache must be 4-dimensional"
assert v_cache.dim() == 4, "v_cache must be 4-dimensional"
assert block_table.dim() == 2, "block_table must be 2-dimensional"
assert cache_seqlens.dim() == 1, "cache_seqlens must be 1-dimensional"
# Get dimensions
batch_size, seqlen_q, num_heads_q, head_size_k = q.shape
num_blocks, page_block_size, num_heads_k, _ = k_cache.shape
max_num_blocks_per_seq = block_table.shape[1]
# Check head dimensions
assert head_size_k == 576 or head_size_k == 512, "Only head_size_k == 576 or 512 is supported"
assert head_dim_v == 512, "Only head_size_v == 512 is supported"
assert num_heads_q % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query"
assert page_block_size == 64, "Currently page_block_size must be 64"
# Set default softmax scale
if softmax_scale is None:
softmax_scale = 1.0 / math.sqrt(head_size_k)
# Create output tensors
out = torch.empty((batch_size, seqlen_q, num_heads_q, head_dim_v), dtype=q.dtype, device=q.device)
lse = torch.empty((batch_size, num_heads_q, seqlen_q), dtype=torch.float32, device=q.device)
# Use simplified implementation that works on all architectures
for b in range(batch_size):
seq_len_k = cache_seqlens[b].item()
# Get query for this batch
q_batch = q[b] # seqlen_q x num_heads_q x head_size_k
# Calculate attention scores using the provided k_cache and block_table
# For gfx926, we'll use a simplified approach
# Get the relevant blocks from the block table
num_k_blocks = (seq_len_k + page_block_size - 1) // page_block_size
blocks = block_table[b, :num_k_blocks].long()
# Ensure blocks are within bounds
blocks = blocks % num_blocks
# Gather the relevant key and value blocks
k = k_cache[blocks].reshape(-1, num_heads_k, head_size_k)[:seq_len_k]
v = v_cache[blocks].reshape(-1, num_heads_k, head_dim_v)[:seq_len_k]
# Handle NaN values
k[k != k] = 0.0
v[v != v] = 0.0
# Expand k and v if needed
if num_heads_k < num_heads_q:
k = k.repeat_interleave(num_heads_q // num_heads_k, dim=1)
v = v.repeat_interleave(num_heads_q // num_heads_k, dim=1)
# Calculate attention scores
# Reshape k for correct matrix multiplication
k_reshaped = k.permute(1, 0, 2) # num_heads_q x seq_len_k x head_size_k
scores = torch.einsum('qhd,hkd->qhk', q_batch, k_reshaped) # seqlen_q x num_heads_q x seq_len_k
scores *= softmax_scale
# Apply causal mask if needed
if causal and seqlen_q > 1:
mask = torch.ones(seqlen_q, seq_len_k, device=q.device, dtype=torch.bool)
mask = mask.tril(diagonal=seq_len_k - seqlen_q)
scores = scores.masked_fill(mask.logical_not().unsqueeze(1), -float('inf'))
# Apply softmax
max_scores = scores.max(dim=-1, keepdim=True)[0]
exp_scores = torch.exp(scores - max_scores)
sum_exp = exp_scores.sum(dim=-1, keepdim=True)
# Calculate lse
current_lse = torch.log(sum_exp.squeeze(-1)) + max_scores.squeeze(-1)
lse[b] = current_lse.transpose(0, 1)
# Calculate attention weights
attention = exp_scores / sum_exp
attention = attention.to(torch.float32)
# Calculate output
# Reshape v for correct matrix multiplication
v_reshaped = v.permute(1, 0, 2) # num_heads_q x seq_len_k x head_dim_v
v_reshaped = v_reshaped.to(torch.float32)
out[b] = torch.einsum('qhk,hkd->qhd', attention, v_reshaped) # seqlen_q x num_heads_q x head_dim_v
out[b] = out[b].to(q.dtype)
# Correct for q tokens which have no attendable k
lonely_q_mask = (current_lse == -float('inf'))
out[b][lonely_q_mask.unsqueeze(-1).broadcast_to(out[b].shape)] = 0.0
lse[b][lonely_q_mask.transpose(0, 1)] = float('inf')
return out, lse
def flash_mla_with_kvcache(
q: torch.Tensor,
k_cache: torch.Tensor,
block_table: torch.Tensor,
cache_seqlens: torch.Tensor,
head_dim_v: int,
tile_scheduler_metadata=None,
num_splits=None,
softmax_scale: Optional[float] = None,
causal: bool = False,
is_fp8_kvcache: bool = False,
indices: Optional[torch.Tensor] = None,
attn_sink: Optional[torch.Tensor] = None,
extra_k_cache: Optional[torch.Tensor] = None,
extra_indices_in_kvcache: Optional[torch.Tensor] = None,
topk_length: Optional[torch.Tensor] = None,
extra_topk_length: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
"""
Wrapper function to match the original flash_mla interface
"""
# For dense attention (no indices provided)
if indices is None:
# Use the first head_dim_v dimensions of k_cache as v_cache
# This matches the reference implementation
head_dim_v_int = head_dim_v.item() if isinstance(head_dim_v, torch.Tensor) else head_dim_v
v_cache = k_cache[..., :head_dim_v_int]
out, lse = flash_mla_with_kvcache_triton(
q, k_cache, v_cache, block_table, cache_seqlens, head_dim_v, softmax_scale, causal
)
return out, lse
else:
# Sparse attention not implemented yet
raise NotImplementedError("Sparse attention is not implemented in Triton version")
def get_mla_metadata(*args, **kwargs) -> Tuple[dict, None]:
"""
Returns a dummy metadata object to match the original interface
"""
return {}, None
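# Minimal usage sketch (illustrative only; the shapes below are hypothetical but satisfy the
# asserts above, and a CUDA/ROCm device is assumed):
if __name__ == "__main__":
    torch.manual_seed(0)
    batch, seqlen_q, h_q, h_kv, d_k, d_v, page = 1, 2, 16, 1, 576, 512, 64
    num_blocks = 4
    q = torch.randn(batch, seqlen_q, h_q, d_k, device="cuda", dtype=torch.float16)
    k_cache = torch.randn(num_blocks, page, h_kv, d_k, device="cuda", dtype=torch.float16)
    block_table = torch.arange(num_blocks, device="cuda", dtype=torch.int32).reshape(1, num_blocks)
    cache_seqlens = torch.tensor([128], device="cuda", dtype=torch.int32)
    out, lse = flash_mla_with_kvcache(q, k_cache, block_table, cache_seqlens, head_dim_v=d_v, causal=True)
    print(out.shape, lse.shape)  # expected: (1, 2, 16, 512) and (1, 16, 2)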
"""
Triton FP8支持测试Demo
测试环境要求:
- Triton >= 2.1.0
- CUDA >= 11.8
- GPU计算能力 >= 8.9 (H100, L40S, etc.)
"""
import torch
import triton
import triton.language as tl
import numpy as np
from typing import Tuple
# Check the Triton version
print(f"Triton version: {triton.__version__}")
print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name(0)}")
print(f"Compute Capability: {torch.cuda.get_device_capability(0)}")
# Check whether the compute capability supports FP8
cc = torch.cuda.get_device_capability(0)
fp8_supported = cc[0] * 10 + cc[1] >= 89 # 8.9+
print(f"FP8 hardware support: {fp8_supported}")
@triton.jit
def add_kernel_fp8(
x_ptr, y_ptr, output_ptr,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
"""使用FP8的向量加法kernel"""
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# Load FP8 data and convert to FP32 for the computation
x = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
y = tl.load(y_ptr + offsets, mask=mask).to(tl.float32)
# Do the computation
output = x + y
# Convert back to FP8 and store
# Note: choose a scale factor appropriate for the FP8 format actually in use
output_fp8 = output.to(tl.float8e5) # or tl.float8e4m3
tl.store(output_ptr + offsets, output_fp8, mask=mask)
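# Illustrative launch (hypothetical; this kernel is not invoked by the tests below):
#   grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
#   add_kernel_fp8[grid](x_fp8, y_fp8, out_fp8, n_elements, BLOCK_SIZE=256)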
@triton.jit
def matmul_kernel_fp8(
a_ptr, b_ptr, c_ptr,
M, N, K,
stride_am, stride_ak,
stride_bk, stride_bn,
stride_cm, stride_cn,
BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr, BLOCK_K: tl.constexpr,
):
"""使用FP8的矩阵乘法kernel"""
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
# 计算当前block的位置
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
# 创建mask
m_mask = offs_m[:, None] < M
n_mask = offs_n[None, :] < N
# 初始化累加器
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
# Load the FP8 tiles and convert to FP32
a_ptrs = a_ptr + (offs_m[:, None] * stride_am + (offs_k[None, :] + k) * stride_ak)
b_ptrs = b_ptr + ((offs_k[:, None] + k) * stride_bk + offs_n[None, :] * stride_bn)
a = tl.load(a_ptrs, mask=m_mask & ((offs_k[None, :] + k) < K))
b = tl.load(b_ptrs, mask=((offs_k[:, None] + k) < K) & n_mask)
# Convert to FP32 for the computation
a_fp32 = a.to(tl.float32)
b_fp32 = b.to(tl.float32)
# Matrix multiply
accumulator += tl.dot(a_fp32, b_fp32)
# Convert the result to FP8 and store it
c = accumulator.to(tl.float8e5)
# Store the result
offs_m_full = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n_full = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
c_ptrs = c_ptr + (offs_m_full[:, None] * stride_cm + offs_n_full[None, :] * stride_cn)
tl.store(c_ptrs, c, mask=m_mask & n_mask)
def create_fp8_tensor(data: torch.Tensor, fp8_type: str = 'e5m2') -> torch.Tensor:
"""创建FP8张量"""
if fp8_type == 'e5m2':
# 转换为float8_e5m2格式
# 需要先缩放到合适范围
max_val = data.abs().max()
scale = 448.0 / max_val if max_val > 0 else 1.0 # e5m2最大值为448
scaled_data = data * scale
return scaled_data.to(torch.float8_e5m2)
elif fp8_type == 'e4m3':
max_val = data.abs().max()
scale = 240.0 / max_val if max_val > 0 else 1.0 # e4m3最大值为240
scaled_data = data * scale
return scaled_data.to(torch.float8_e4m3fn)
else:
raise ValueError(f"Unsupported FP8 type: {fp8_type}")
def create_fp8_tensor_with_scaling(data: torch.Tensor, fp8_type: str = 'e5m2'):
"""创建带缩放因子的FP8张量"""
if fp8_type == 'e5m2':
fp8_max = 57344.0
data_max = data.abs().max()
scale = fp8_max / data_max if data_max > 0 else 1.0
scaled_data = data * scale
# Clamp values into the FP8 range
scaled_data = torch.clamp(scaled_data, -fp8_max, fp8_max)
fp8_data = scaled_data.to(torch.float8_e5m2)
return fp8_data, scale
elif fp8_type == 'e4m3':
fp8_max = 448.0
data_max = data.abs().max()
scale = fp8_max / data_max if data_max > 0 else 1.0
scaled_data = data * scale
scaled_data = torch.clamp(scaled_data, -fp8_max, fp8_max)
fp8_data = scaled_data.to(torch.float8_e4m3fn)
return fp8_data, scale
else:
raise ValueError(f"Unsupported FP8 type: {fp8_type}")
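# For reference (illustrative, not part of the original): real values are recovered by dividing
# the FP8 payload by the returned scale, e.g.
#   x_fp8, s = create_fp8_tensor_with_scaling(x, 'e5m2')
#   x_restored = x_fp8.float() / s  # approximately equal to x, up to FP8 rounding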
@triton.jit
def add_kernel_fp8_workaround(
x_ptr, y_ptr, output_ptr,
scale_x, scale_y, scale_out,
n_elements,
BLOCK_SIZE: tl.constexpr,
):
"""绕过类型问题的FP8加法"""
pid = tl.program_id(axis=0)
block_start = pid * BLOCK_SIZE
offsets = block_start + tl.arange(0, BLOCK_SIZE)
mask = offsets < n_elements
# Load and convert in one step
# 1. Load the data and immediately convert it to float32
x_loaded = tl.load(x_ptr + offsets, mask=mask).to(tl.float32)
# 2. Then divide the now explicitly float32 value by its scale
x = x_loaded / scale_x
# 1. Load the data and immediately convert it to float32
y_loaded = tl.load(y_ptr + offsets, mask=mask).to(tl.float32)
# 2. Then divide the now explicitly float32 value by its scale
y = y_loaded / scale_y
# Compute
output = x + y
# Store
output_scaled = output * scale_out
tl.store(output_ptr + offsets, output_scaled.to(tl.float8e5), mask=mask)
def test_fp8_vector_addition_fixed():
"""修复后的FP8向量加法测试"""
print("\n" + "="*50)
print("测试: 带缩放因子的FP8向量加法")
print("="*50)
n_elements = 1024
# Prepare the data
torch.manual_seed(42)
x = torch.randn(n_elements, device='cuda', dtype=torch.float32) * 2
y = torch.randn(n_elements, device='cuda', dtype=torch.float32) * 2
# Create scaled FP8 tensors
x_fp8, scale_x = create_fp8_tensor_with_scaling(x, 'e5m2')
y_fp8, scale_y = create_fp8_tensor_with_scaling(y, 'e5m2')
output_fp8 = torch.empty_like(x_fp8)
# Compute the output scale factor
expected_max = (x.abs().max() + y.abs().max()).item() * 1.5
scale_out = 57344.0 / expected_max if expected_max > 0 else 1.0
print(f"Scale factors: scale_x={scale_x:.4f}, scale_y={scale_y:.4f}, scale_out={scale_out:.4f}")
print(f"Data ranges: x=[{x.min():.3f}, {x.max():.3f}], y=[{y.min():.3f}, {y.max():.3f}]")
# Configure the kernel
BLOCK_SIZE = 256
grid = (triton.cdiv(n_elements, BLOCK_SIZE),)
# Launch the kernel; note that the scale factors are passed as scalars
add_kernel_fp8_workaround[grid](
x_fp8, y_fp8, output_fp8,
scale_x, scale_y, scale_out,
n_elements, BLOCK_SIZE
)
# Recover the real-valued data
output = output_fp8.float() / scale_out
expected = x + y
# Compute the errors
abs_error = (output - expected).abs()
rel_error = abs_error / (expected.abs() + 1e-8)
print(f"\n前10个结果对比:")
print(f"预期: {expected[:10].cpu()}")
print(f"实际: {output[:10].cpu()}")
print(f"差异: {(output - expected)[:10].cpu()}")
print(f"\n误差统计:")
print(f"平均绝对误差: {abs_error.mean():.6f}")
print(f"最大绝对误差: {abs_error.max():.6f}")
print(f"平均相对误差: {rel_error.mean():.6f}")
print(f"最大相对误差: {rel_error.max():.6f}")
# 检查是否合理
if rel_error.mean() < 0.1:
print("✓ FP8精度在可接受范围内")
return True
else:
print("⚠️ FP8精度损失较大")
return False
def test_fp8_matrix_multiplication():
"""测试FP8矩阵乘法"""
print("\n" + "="*50)
print("测试2: FP8矩阵乘法")
print("="*50)
try:
# Configure the matrix dimensions
M, N, K = 512, 512, 256
# Create the input matrices
a = torch.randn((M, K), device='cuda', dtype=torch.float32)
b = torch.randn((K, N), device='cuda', dtype=torch.float32)
# Convert to FP8
a_fp8 = create_fp8_tensor(a, 'e5m2')
b_fp8 = create_fp8_tensor(b, 'e5m2')
c_fp8 = torch.empty((M, N), device='cuda', dtype=torch.float8_e5m2)
# Kernel configuration
BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 64
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
# Run the matrix multiplication
matmul_kernel_fp8[grid](
a_fp8, b_fp8, c_fp8,
M, N, K,
a.stride(0), a.stride(1),
b.stride(0), b.stride(1),
c_fp8.stride(0), c_fp8.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K
)
# Convert back to FP32
c_output = c_fp8.float()
# Verify the result (against PyTorch's matrix multiplication)
expected = torch.mm(a, b)
# Normalized error
norm_expected = torch.norm(expected)
rel_error = torch.norm(c_output - expected) / norm_expected
print(f"✓ FP8矩阵乘法成功执行")
print(f" 矩阵维度: {M}x{K} * {K}x{N} = {M}x{N}")
print(f" 相对误差: {rel_error:.6f}")
# 显示部分结果
print(f" 预期结果[0:5,0:5]: {expected[0,0]:.4f} {expected[0,1]:.4f} ...")
print(f" 实际结果[0:5,0:5]: {c_output[0,0]:.4f} {c_output[0,1]:.4f} ...")
return True
except Exception as e:
print(f"✗ FP8矩阵乘法失败: {e}")
return False
def test_fp8_dtype_support():
"""测试FP8数据类型支持"""
print("\n" + "="*50)
print("测试3: FP8数据类型支持")
print("="*50)
# 检查torch的FP8支持
fp8_types = {
'float8_e5m2': hasattr(torch, 'float8_e5m2'),
'float8_e4m3fn': hasattr(torch, 'float8_e4m3fn'),
}
for dtype_name, supported in fp8_types.items():
status = "✓" if supported else "✗"
print(f"{status} torch.{dtype_name}: {supported}")
# Check Triton's FP8 support
triton_fp8_types = {
'float8e5': hasattr(tl, 'float8e5'),
'float8e4': hasattr(tl, 'float8e4'),
}
for dtype_name, supported in triton_fp8_types.items():
status = "✓" if supported else "✗"
print(f"{status} tl.{dtype_name}: {supported}")
return any(fp8_types.values()) and any(triton_fp8_types.values())
def performance_comparison():
"""性能对比:FP32 vs FP8"""
print("\n" + "="*50)
print("测试4: 性能对比 (FP32 vs FP8)")
print("="*50)
try:
import time
M, N, K = 1024, 1024, 512
# FP32矩阵乘法
a_fp32 = torch.randn((M, K), device='cuda', dtype=torch.float32)
b_fp32 = torch.randn((K, N), device='cuda', dtype=torch.float32)
# 预热
for _ in range(10):
torch.mm(a_fp32, b_fp32)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
torch.mm(a_fp32, b_fp32)
torch.cuda.synchronize()
fp32_time = (time.time() - start) / 100
# FP8 matrix multiplication
a_fp8 = create_fp8_tensor(a_fp32, 'e5m2')
b_fp8 = create_fp8_tensor(b_fp32, 'e5m2')
c_fp8 = torch.empty((M, N), device='cuda', dtype=torch.float8_e5m2)
BLOCK_M, BLOCK_N, BLOCK_K = 64, 64, 64
grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
# Warm-up
for _ in range(10):
matmul_kernel_fp8[grid](
a_fp8, b_fp8, c_fp8,
M, N, K,
a_fp8.stride(0), a_fp8.stride(1),
b_fp8.stride(0), b_fp8.stride(1),
c_fp8.stride(0), c_fp8.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K
)
torch.cuda.synchronize()
start = time.time()
for _ in range(100):
matmul_kernel_fp8[grid](
a_fp8, b_fp8, c_fp8,
M, N, K,
a_fp8.stride(0), a_fp8.stride(1),
b_fp8.stride(0), b_fp8.stride(1),
c_fp8.stride(0), c_fp8.stride(1),
BLOCK_M, BLOCK_N, BLOCK_K
)
torch.cuda.synchronize()
fp8_time = (time.time() - start) / 100
speedup = fp32_time / fp8_time
print(f"FP32 平均时间: {fp32_time*1000:.3f} ms")
print(f"FP8 平均时间: {fp8_time*1000:.3f} ms")
print(f"加速比: {speedup:.2f}x")
return speedup > 1.0
except Exception as e:
print(f"性能测试失败: {e}")
return False
def main():
"""主测试函数"""
print("\n" + "🚀 Triton FP8 支持测试套件")
print("="*50)
# 检查硬件支持
if not torch.cuda.is_available():
print("❌ CUDA不可用,无法测试FP8")
return
cc = torch.cuda.get_device_capability(0)
if cc[0] * 10 + cc[1] < 89:
print(f"⚠️ 警告: GPU计算能力{cc[0]}.{cc[1]} < 8.9")
print(" FP8需要H100、L40S或更新的GPU")
print(" 继续测试但可能会失败...")
# 运行测试
results = {}
# 测试1: 数据类型支持
results['dtype_support'] = test_fp8_dtype_support()
# 测试2: 向量加法
if results['dtype_support']:
results['vector_add'] = test_fp8_vector_addition_fixed()
# 测试3: 矩阵乘法
results['matmul'] = test_fp8_matrix_multiplication()
# 测试4: 性能对比
if results['matmul']:
results['performance'] = performance_comparison()
else:
print("\n❌ FP8数据类型不支持,跳过后续测试")
# 汇总结果
print("\n" + "="*50)
print("📊 测试结果汇总")
print("="*50)
for test_name, passed in results.items():
status = "✅ PASS" if passed else "❌ FAIL"
print(f"{status}: {test_name}")
# Final verdict
print("\n" + "="*50)
if results.get('dtype_support', False):
print("🎉 Your Triton installation supports FP8!")
if results.get('performance', False):
print("⚡ FP8 shows a performance gain and can be used for acceleration")
else:
print("⚠️ FP8 works, but the performance gain is not significant")
else:
print("❌ The current environment does not support FP8")
print("\nSuggestions:")
print("1. Upgrade Triton: pip install --upgrade triton")
print("2. Upgrade PyTorch: pip install --upgrade torch")
print("3. Make sure you are using a supported GPU (H100, L40S, etc.)")
print("4. Check the CUDA version: nvcc --version")
if __name__ == "__main__":
main()