Commit c601083d authored by liuys

update triton

parent 2add9fa3
rm -rf *.db
rm -rf *.csv
rm -rf *.txt
rm -rf *.json
\ No newline at end of file
/***************************************************************************************************
* Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
This example demonstrates how to call a CUTLASS GEMM kernel and provides a naive reference
matrix multiply kernel to verify its correctness.
The CUTLASS Gemm template is instantiated in the function CutlassSgemmNN. This kernel computes
the general matrix product (GEMM) using single-precision floating-point arithmetic and assumes
all matrices have column-major layout.
The threadblock tile size is chosen as 128x128x8 which offers good performance for large matrices.
See the CUTLASS Parallel for All blog post for more exposition on the tunable parameters available
in CUTLASS.
https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/
Aside from defining and launching the SGEMM kernel, this example does not use any other components
or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are
prevalent in the CUTLASS unit tests.
This example has deliberately been kept similar to the basic_gemm example from cutlass-1.3 to
highlight the minimum amount of differences needed to transition to cutlass-2.0.
Cutlass-1.3 sgemm: https://github.com/NVIDIA/cutlass/blob/master/examples/00_basic_gemm/basic_gemm.cu
*/
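// A hedged build-and-run sketch; the CUTLASS checkout path, the -arch flag, and the file/output
// names are assumptions, not part of this commit (helper.h ships in cutlass/examples/common):
//
//   nvcc -std=c++17 -arch=sm_80 \
//        -I$CUTLASS_DIR/include -I$CUTLASS_DIR/examples/common \
//        basic_gemm.cu -o basic_gemm
//   ./basic_gemm 1024 512 2048 1.0 0.0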
// Standard Library includes
#include <iostream>
#include <sstream>
#include <vector>
// Helper methods to check for errors
#include "helper.h"
//
// CUTLASS includes needed for single-precision GEMM kernel
//
// Defines cutlass::gemm::device::Gemm, the generic Gemm computation template class.
#include "cutlass/gemm/device/gemm.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// This function defines a CUTLASS GEMM kernel instantiation, constructs its parameters object,
// and launches it on the CUDA device.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Define a CUTLASS GEMM template and launch a GEMM kernel.
cudaError_t CutlassSgemmNN(
int M,
int N,
int K,
float alpha,
float const *A,
int lda,
float const *B,
int ldb,
float beta,
float *C,
int ldc) {
// Define type definition for single-precision CUTLASS GEMM with column-major
// input matrices and 128x128x8 threadblock tile size (chosen by default).
//
// To keep the interface manageable, several helpers are defined for plausible compositions
// including the following example for single-precision GEMM. Typical values are used as
// default template arguments. See `cutlass/gemm/device/default_gemm_configuration.h` for more details.
//
// To view the full gemm device API interface, see `cutlass/gemm/device/gemm.h`
using ColumnMajor = cutlass::layout::ColumnMajor;
using CutlassGemm = cutlass::gemm::device::Gemm<float, // Data-type of A matrix
ColumnMajor, // Layout of A matrix
float, // Data-type of B matrix
ColumnMajor, // Layout of B matrix
float, // Data-type of C matrix
ColumnMajor>; // Layout of C matrix
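  // A hedged sketch of how the 128x128x8 threadblock tile mentioned above could be stated
  // explicitly instead of relying on the defaults; the accumulator type, operator class, and
  // architecture tag below are assumptions taken from default_gemm_configuration.h and may
  // need adjusting for other CUTLASS versions or targets.
  //
  //   using CutlassGemmExplicit = cutlass::gemm::device::Gemm<
  //       float, ColumnMajor,                      // A
  //       float, ColumnMajor,                      // B
  //       float, ColumnMajor,                      // C
  //       float,                                   // accumulator
  //       cutlass::arch::OpClassSimt,              // SGEMM runs on CUDA cores (SIMT)
  //       cutlass::arch::Sm70,                     // architecture tag
  //       cutlass::gemm::GemmShape<128, 128, 8>>;  // threadblock tile M x N x K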
// Define a CUTLASS GEMM type
CutlassGemm gemm_operator;
// Construct the CUTLASS GEMM arguments object.
//
// One of CUTLASS's design patterns is to define gemm argument objects that are constructible
// in host code and passed to kernels by value. These may include pointers, strides, scalars,
// and other arguments needed by Gemm and its components.
//
// The benefits of this pattern are (1.) a structured, composable strategy for passing host-constructible
// arguments to kernels and (2.) minimized initialization overhead on kernel entry.
//
CutlassGemm::Arguments args({M , N, K}, // Gemm Problem dimensions
{A, lda}, // Tensor-ref for source matrix A
{B, ldb}, // Tensor-ref for source matrix B
{C, ldc}, // Tensor-ref for source matrix C
{C, ldc}, // Tensor-ref for destination matrix D (may be different memory than source C matrix)
{alpha, beta}); // Scalars used in the Epilogue
//
// Launch the CUTLASS GEMM kernel.
//
cutlass::Status status = gemm_operator(args);
//
// Return a cudaError_t if the CUTLASS GEMM operator returned an error code.
//
if (status != cutlass::Status::kSuccess) {
return cudaErrorUnknown;
}
// Return success, if no errors were encountered.
return cudaSuccess;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// The source code after this point in the file is generic CUDA using the CUDA Runtime API
// and simple CUDA kernels to initialize matrices and compute the general matrix product.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Kernel to initialize a matrix with small integers.
__global__ void InitializeMatrix_kernel(
float *matrix,
int rows,
int columns,
int seed = 0) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
int j = threadIdx.y + blockIdx.y * blockDim.y;
if (i < rows && j < columns) {
int offset = i + j * rows;
// Generate arbitrary elements.
int const k = 16807;
int const m = 16;
float value = float(((offset + seed) * k % m) - m / 2);
matrix[offset] = value;
}
}
/// Simple function to initialize a matrix to arbitrary small integers.
cudaError_t InitializeMatrix(float *matrix, int rows, int columns, int seed = 0) {
dim3 block(16, 16);
dim3 grid(
(rows + block.x - 1) / block.x,
(columns + block.y - 1) / block.y
);
InitializeMatrix_kernel<<< grid, block >>>(matrix, rows, columns, seed);
return cudaGetLastError();
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Allocates device memory for a matrix then fills with arbitrary small integers.
cudaError_t AllocateMatrix(float **matrix, int rows, int columns, int seed = 0) {
cudaError_t result;
size_t sizeof_matrix = sizeof(float) * rows * columns;
// Allocate device memory.
result = cudaMalloc(reinterpret_cast<void **>(matrix), sizeof_matrix);
if (result != cudaSuccess) {
std::cerr << "Failed to allocate matrix: "
<< cudaGetErrorString(result) << std::endl;
return result;
}
// Clear the allocation.
result = cudaMemset(*matrix, 0, sizeof_matrix);
if (result != cudaSuccess) {
std::cerr << "Failed to clear matrix device memory: "
<< cudaGetErrorString(result) << std::endl;
return result;
}
// Initialize matrix elements to arbitrary small integers.
result = InitializeMatrix(*matrix, rows, columns, seed);
if (result != cudaSuccess) {
std::cerr << "Failed to initialize matrix: "
<< cudaGetErrorString(result) << std::endl;
return result;
}
return result;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Naive reference GEMM computation.
__global__ void ReferenceGemm_kernel(
int M,
int N,
int K,
float alpha,
float const *A,
int lda,
float const *B,
int ldb,
float beta,
float *C,
int ldc) {
int i = threadIdx.x + blockIdx.x * blockDim.x;
int j = threadIdx.y + blockIdx.y * blockDim.y;
if (i < M && j < N) {
float accumulator = 0;
for (int k = 0; k < K; ++k) {
accumulator += A[i + k * lda] * B[k + j * ldb];
}
C[i + j * ldc] = alpha * accumulator + beta * C[i + j * ldc];
}
}
/// Reference GEMM computation.
cudaError_t ReferenceGemm(
int M,
int N,
int K,
float alpha,
float const *A,
int lda,
float const *B,
int ldb,
float beta,
float *C,
int ldc) {
dim3 block(16, 16);
dim3 grid(
(M + block.x - 1) / block.x,
(N + block.y - 1) / block.y
);
ReferenceGemm_kernel<<< grid, block >>>(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
return cudaGetLastError();
}
///////////////////////////////////////////////////////////////////////////////////////////////////
#define TILE_SIZE 16
__global__ void TiledGemm_kernel(
int M,
int N,
int K,
float alpha,
float const *A,
int lda,
float const *B,
int ldb,
float beta,
float *C,
int ldc) {
__shared__ float As[TILE_SIZE][TILE_SIZE];
__shared__ float Bs[TILE_SIZE][TILE_SIZE];
int tx = threadIdx.x;
int ty = threadIdx.y;
int row = blockIdx.y * TILE_SIZE + ty;
int col = blockIdx.x * TILE_SIZE + tx;
float accumulator = 0;
for (int tile = 0; tile < (K + TILE_SIZE - 1) / TILE_SIZE; ++tile) {
if (row < M && (tile * TILE_SIZE + tx) < K) {
As[ty][tx] = A[row + (tile * TILE_SIZE + tx) * lda];
} else {
As[ty][tx] = 0;
}
if (col < N && (tile * TILE_SIZE + ty) < K) {
Bs[ty][tx] = B[(tile * TILE_SIZE + ty) + col * ldb];
} else {
Bs[ty][tx] = 0;
}
__syncthreads();
for (int k = 0; k < TILE_SIZE; ++k) {
accumulator += As[ty][k] * Bs[k][tx];
}
__syncthreads();
}
if (row < M && col < N) {
C[row + col * ldc] = alpha * accumulator + beta * C[row + col * ldc];
}
}
cudaError_t TiledGemm(
int M,
int N,
int K,
float alpha,
float const *A,
int lda,
float const *B,
int ldb,
float beta,
float *C,
int ldc) {
dim3 block(TILE_SIZE, TILE_SIZE);
dim3 grid(
(N + TILE_SIZE - 1) / TILE_SIZE,
(M + TILE_SIZE - 1) / TILE_SIZE
);
TiledGemm_kernel<<< grid, block >>>(M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);
return cudaGetLastError();
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Allocate several matrices in GPU device memory and call a single-precision
/// CUTLASS GEMM kernel.
cudaError_t TestCutlassGemm(int M, int N, int K, float alpha, float beta) {
cudaError_t result;
//
// Define several matrices to be used as operands to GEMM kernels.
//
// Compute leading dimensions for each matrix.
int lda = M;
int ldb = K;
int ldc = M;
// Compute size in bytes of the C matrix.
size_t sizeof_C = sizeof(float) * ldc * N;
// Define pointers to matrices in GPU device memory.
float *A;
float *B;
float *C_cutlass;
float *C_reference;
//
// Allocate matrices in GPU device memory with arbitrary seeds.
//
result = AllocateMatrix(&A, M, K, 0);
if (result != cudaSuccess) {
return result;
}
result = AllocateMatrix(&B, K, N, 17);
if (result != cudaSuccess) {
cudaFree(A);
return result;
}
result = AllocateMatrix(&C_cutlass, M, N, 101);
if (result != cudaSuccess) {
cudaFree(A);
cudaFree(B);
return result;
}
result = AllocateMatrix(&C_reference, M, N, 101);
if (result != cudaSuccess) {
cudaFree(A);
cudaFree(B);
cudaFree(C_cutlass);
return result;
}
result = cudaMemcpy(C_reference, C_cutlass, sizeof_C, cudaMemcpyDeviceToDevice);
if (result != cudaSuccess) {
std::cerr << "Failed to copy C_cutlass matrix to C_reference: "
<< cudaGetErrorString(result) << std::endl;
cudaFree(C_reference);
cudaFree(C_cutlass);
cudaFree(B);
cudaFree(A);
return result;
}
//
// Launch CUTLASS GEMM.
//
result = CutlassSgemmNN(M, N, K, alpha, A, lda, B, ldb, beta, C_cutlass, ldc);
if (result != cudaSuccess) {
std::cerr << "CUTLASS GEMM kernel failed: "
<< cudaGetErrorString(result) << std::endl;
cudaFree(C_reference);
cudaFree(C_cutlass);
cudaFree(B);
cudaFree(A);
return result;
}
//
// Verify.
//
// Launch reference GEMM
result = ReferenceGemm(M, N, K, alpha, A, lda, B, ldb, beta, C_reference, ldc);
if (result != cudaSuccess) {
std::cerr << "Reference GEMM kernel failed: "
<< cudaGetErrorString(result) << std::endl;
cudaFree(C_reference);
cudaFree(C_cutlass);
cudaFree(B);
cudaFree(A);
return result;
}
// Copy to host and verify equivalence.
std::vector<float> host_cutlass(ldc * N, 0);
std::vector<float> host_reference(ldc * N, 0);
result = cudaMemcpy(host_cutlass.data(), C_cutlass, sizeof_C, cudaMemcpyDeviceToHost);
if (result != cudaSuccess) {
std::cerr << "Failed to copy CUTLASS GEMM results: "
<< cudaGetErrorString(result) << std::endl;
cudaFree(C_reference);
cudaFree(C_cutlass);
cudaFree(B);
cudaFree(A);
return result;
}
result = cudaMemcpy(host_reference.data(), C_reference, sizeof_C, cudaMemcpyDeviceToHost);
if (result != cudaSuccess) {
std::cerr << "Failed to copy Reference GEMM results: "
<< cudaGetErrorString(result) << std::endl;
cudaFree(C_reference);
cudaFree(C_cutlass);
cudaFree(B);
cudaFree(A);
return result;
}
//
// Free device memory allocations.
//
cudaFree(C_reference);
cudaFree(C_cutlass);
cudaFree(B);
cudaFree(A);
//
// Test for bit equivalence of results.
//
if (host_cutlass != host_reference) {
std::cerr << "CUTLASS results incorrect." << std::endl;
return cudaErrorUnknown;
}
return cudaSuccess;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Entry point to basic_gemm example.
//
// usage:
//
// 00_basic_gemm <M> <N> <K> <alpha> <beta>
//
int main(int argc, const char *arg[]) {
//
// Parse the command line to obtain GEMM dimensions and scalar values.
//
// GEMM problem dimensions.
int problem[3] = { 128, 128, 128 };
for (int i = 1; i < argc && i < 4; ++i) {
std::stringstream ss(arg[i]);
ss >> problem[i - 1];
}
// Scalars used for linear scaling the result of the matrix product.
float scalars[2] = { 1, 0 };
for (int i = 4; i < argc && i < 6; ++i) {
std::stringstream ss(arg[i]);
ss >> scalars[i - 4];
}
//
// Run the CUTLASS GEMM test.
//
cudaError_t result = TestCutlassGemm(
problem[0], // GEMM M dimension
problem[1], // GEMM N dimension
problem[2], // GEMM K dimension
scalars[0], // alpha
scalars[1] // beta
);
if (result == cudaSuccess) {
std::cout << "Passed." << std::endl;
}
// Exit.
return result == cudaSuccess ? 0 : -1;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
\ No newline at end of file
#include <cuda_runtime.h>
#include <iostream>
#include <cmath>
#include <cstdlib>
#define CHECK_CUDA(call) \
do { \
cudaError_t err = call; \
if (err != cudaSuccess) { \
std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ << " - " \
<< cudaGetErrorString(err) << std::endl; \
exit(EXIT_FAILURE); \
} \
} while (0)
// Tiling parameters. Each thread accumulates a 4x4 fragment of C; the 32 lanes of a warp are
// arranged as 8 rows x 4 columns of threads, so one warp owns a kWarpM x kWarpN = 32x16 tile.
// A 64x64 threadblock tile therefore needs 2x4 = 8 warps (256 threads), and the shared-memory
// staging buffers take (64*32 + 32*64) * 4 bytes = 16 KiB per block.
constexpr int kWarpM = 32;
constexpr int kWarpN = 16;
constexpr int kBlockM = 64;
constexpr int kBlockN = 64;
constexpr int kBlockK = 32;
constexpr int kWarpNumM = kBlockM / kWarpM;
constexpr int kWarpNumN = kBlockN / kWarpN;
constexpr int kThreadsPerBlock = kWarpNumM * kWarpNumN * 32;
__global__ void TiledGemmKernel(
int M, int N, int K,
float alpha,
const float* __restrict__ A,
const float* __restrict__ B,
float beta,
float* __restrict__ C) {
const int lda = M;
const int ldb = K;
const int ldc = M;
__shared__ float smemA[kBlockM][kBlockK];
__shared__ float smemB[kBlockK][kBlockN];
const int warpId = threadIdx.x / 32;
const int laneId = threadIdx.x % 32;
const int warpRow = warpId / kWarpNumN;
const int warpCol = warpId % kWarpNumN;
  // Each thread owns a 4x4 fragment of its warp's tile; lanes form an 8x4 grid.
  const int threadRowInWarp = laneId / 4;   // 0..7
  const int threadColInWarp = laneId % 4;   // 0..3
  const int blockRow = blockIdx.y * kBlockM;
  const int blockCol = blockIdx.x * kBlockN;
  // 4x4 accumulator per thread; 8x4 threads * (4x4) covers the warp's 32x16 tile.
  float acc[4][4] = {0};
  const int numTiles = (K + kBlockK - 1) / kBlockK;
  for (int tileIdx = 0; tileIdx < numTiles; ++tileIdx) {
    // Cooperatively stage the A tile (kBlockM x kBlockK) in shared memory.
    for (int i = threadIdx.x; i < kBlockM * kBlockK; i += blockDim.x) {
      int row = i / kBlockK;
      int col = i % kBlockK;
      int globalRow = blockRow + row;
      int globalCol = tileIdx * kBlockK + col;
      if (globalRow < M && globalCol < K) {
        smemA[row][col] = A[globalRow + globalCol * lda];
      } else {
        smemA[row][col] = 0.0f;
      }
    }
    // Cooperatively stage the B tile (kBlockK x kBlockN) in shared memory.
    for (int i = threadIdx.x; i < kBlockK * kBlockN; i += blockDim.x) {
      int row = i / kBlockN;
      int col = i % kBlockN;
      int globalRow = tileIdx * kBlockK + row;
      int globalCol = blockCol + col;
      if (globalRow < K && globalCol < N) {
        smemB[row][col] = B[globalRow + globalCol * ldb];
      } else {
        smemB[row][col] = 0.0f;
      }
    }
    __syncthreads();
    // Accumulate this K slice: one rank-1 update of the 4x4 fragment per kk.
    const int warpStartRow = warpRow * kWarpM;
    const int warpStartCol = warpCol * kWarpN;
    for (int kk = 0; kk < kBlockK; ++kk) {
      float aFrag[4];
      float bFrag[4];
      #pragma unroll
      for (int i = 0; i < 4; ++i) {
        aFrag[i] = smemA[warpStartRow + threadRowInWarp * 4 + i][kk];
      }
      #pragma unroll
      for (int j = 0; j < 4; ++j) {
        bFrag[j] = smemB[kk][warpStartCol + threadColInWarp * 4 + j];
      }
      #pragma unroll
      for (int i = 0; i < 4; ++i) {
        #pragma unroll
        for (int j = 0; j < 4; ++j) {
          acc[i][j] += aFrag[i] * bFrag[j];
        }
      }
    }
    __syncthreads();
  }
  // Write back the 4x4 fragment with bounds checks.
  const int warpStartRow = blockRow + warpRow * kWarpM;
  const int warpStartCol = blockCol + warpCol * kWarpN;
  for (int i = 0; i < 4; ++i) {
    int row = warpStartRow + threadRowInWarp * 4 + i;
    if (row >= M) continue;
    for (int j = 0; j < 4; ++j) {
      int col = warpStartCol + threadColInWarp * 4 + j;
      if (col >= N) continue;
      int idx = row + col * ldc;
      C[idx] = alpha * acc[i][j] + beta * C[idx];
    }
  }
}
void TiledGemm(
int M, int N, int K,
float alpha,
const float* A,
const float* B,
float beta,
float* C) {
dim3 block(kThreadsPerBlock);
dim3 grid(
(N + kBlockN - 1) / kBlockN,
(M + kBlockM - 1) / kBlockM
);
TiledGemmKernel<<<grid, block>>>(M, N, K, alpha, A, B, beta, C);
CHECK_CUDA(cudaDeviceSynchronize());
}
void ReferenceGemm(
int M, int N, int K,
float alpha,
const float* A,
const float* B,
float beta,
float* C) {
for (int i = 0; i < M; ++i) {
for (int j = 0; j < N; ++j) {
float sum = 0;
for (int k = 0; k < K; ++k) {
sum += A[i + k * M] * B[k + j * K];
}
C[i + j * M] = alpha * sum + beta * C[i + j * M];
}
}
}
void RandomInit(float* data, int size) {
for (int i = 0; i < size; ++i) {
data[i] = (float(rand()) / RAND_MAX) * 2.0f - 1.0f;
}
}
bool Verify(const float* C1, const float* C2, int M, int N, float tolerance = 1e-3f) {
for (int i = 0; i < M; ++i) {
for (int j = 0; j < N; ++j) {
float diff = fabsf(C1[i + j * M] - C2[i + j * M]);
if (diff > tolerance) {
std::cerr << "Mismatch at C[" << i << "," << j << "]: "
<< C1[i + j * M] << " vs " << C2[i + j * M]
<< " (diff=" << diff << ")" << std::endl;
return false;
}
}
}
return true;
}
int main(int argc, char** argv) {
int M = 512;
int N = 512;
int K = 512;
float alpha = 1.0f;
float beta = 0.0f;
if (argc >= 4) {
M = atoi(argv[1]);
N = atoi(argv[2]);
K = atoi(argv[3]);
}
std::cout << "GEMM: M=" << M << ", N=" << N << ", K=" << K << std::endl;
float *h_A, *h_B, *h_C_tiled, *h_C_ref;
float *d_A, *d_B, *d_C;
h_A = new float[M * K];
h_B = new float[K * N];
h_C_tiled = new float[M * N];
h_C_ref = new float[M * N]();  // zero-initialize so the beta * C term in the reference is well-defined
RandomInit(h_A, M * K);
RandomInit(h_B, K * N);
CHECK_CUDA(cudaMalloc(&d_A, M * K * sizeof(float)));
CHECK_CUDA(cudaMalloc(&d_B, K * N * sizeof(float)));
CHECK_CUDA(cudaMalloc(&d_C, M * N * sizeof(float)));
CHECK_CUDA(cudaMemcpy(d_A, h_A, M * K * sizeof(float), cudaMemcpyHostToDevice));
CHECK_CUDA(cudaMemcpy(d_B, h_B, K * N * sizeof(float), cudaMemcpyHostToDevice));
CHECK_CUDA(cudaMemset(d_C, 0, M * N * sizeof(float)));
cudaEvent_t start, stop;
CHECK_CUDA(cudaEventCreate(&start));
CHECK_CUDA(cudaEventCreate(&stop));
CHECK_CUDA(cudaEventRecord(start));
TiledGemm(M, N, K, alpha, d_A, d_B, beta, d_C);
CHECK_CUDA(cudaEventRecord(stop));
CHECK_CUDA(cudaEventSynchronize(stop));
float milliseconds = 0;
CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));
CHECK_CUDA(cudaMemcpy(h_C_tiled, d_C, M * N * sizeof(float), cudaMemcpyDeviceToHost));
ReferenceGemm(M, N, K, alpha, h_A, h_B, beta, h_C_ref);
bool passed = Verify(h_C_tiled, h_C_ref, M, N);
float tflops = (2.0f * M * N * K) / (milliseconds * 1e-3f) / 1e12f;
std::cout << "Tiled GEMM: " << milliseconds << " ms" << std::endl;
std::cout << "Performance: " << tflops << " TFLOPS" << std::endl;
std::cout << "Result: " << (passed ? "PASSED" : "FAILED") << std::endl;
CHECK_CUDA(cudaEventDestroy(start));
CHECK_CUDA(cudaEventDestroy(stop));
CHECK_CUDA(cudaFree(d_A));
CHECK_CUDA(cudaFree(d_B));
CHECK_CUDA(cudaFree(d_C));
delete[] h_A;
delete[] h_B;
delete[] h_C_tiled;
delete[] h_C_ref;
return passed ? 0 : 1;
}
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
# Stand-in for grab_first_if_tuple (return the first element if the value is a tuple, otherwise return it unchanged)
def grab_first_if_tuple(x):
return x[0] if isinstance(x, tuple) else x
class ParallelGatedMLP(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.layer_idx = layer_idx
multiple_of = config.get("inner_size_multiple_of", 64)
self.act_type = config.get("mlp_activation", "gelu")
if self.act_type == "gelu":
self.act = F.gelu
elif self.act_type == "silu":
self.act = F.silu
else:
raise NotImplementedError
if self.layer_idx > 0 and config.get("evo2_style_activations", False):
self.act = nn.Identity()
inner_size = 11264
self.l1 = nn.Linear(
in_features=config.get("hidden_size", 4096),
out_features=inner_size,
bias=False,
)
self.l2 = nn.Linear(
in_features=config.get("hidden_size", 4096),
out_features=inner_size,
bias=False,
)
self.l3 = nn.Linear(
in_features=inner_size,
out_features=config.get("hidden_size", 4096),
bias=False,
)
        # Make sure the weights are contiguous (nn.Linear weights normally are; this is just to be safe)
self.l1.weight = torch.nn.Parameter(self.l1.weight.contiguous())
self.l2.weight = torch.nn.Parameter(self.l2.weight.contiguous())
self.l3.weight = torch.nn.Parameter(self.l3.weight.contiguous())
def forward(self, z):
z1, z2 = self.l1(z), self.l2(z)
return z1, z2
# === Example usage ===
if __name__ == "__main__":
    # Mock config
    config = {
        "hidden_size": 4096,
        "mlp_activation": "silu",
        "model_parallel_size": 1,
        "evo2_style_activations": False,
    }
    layer_idx = 0
    # Create the model
    model = ParallelGatedMLP(config, layer_idx)
    # Cast the model to bfloat16
    model = model.to(dtype=torch.bfloat16, device="cuda:0")
    # Create the input tensor (batch=1, seq_len=1, hidden=4096)
    device = "cuda:0"  # or "cuda" on any GPU with bf16 support (e.g. A100, H100)
    x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)
    # Inference
with torch.no_grad():
output = model(x)
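        # A hedged sketch (assumption, not part of the module's forward): forward() only returns
        # the two projections, so the full gated-MLP output would be assembled from them as
        # below, i.e. act(x @ W1^T) * (x @ W2^T) followed by the l3 projection.
        z1, z2 = output
        gated = model.act(z1) * z2
        y = model.l3(gated)
        print(y.shape)  # expected: torch.Size([1, 1, 4096])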
\ No newline at end of file
import triton
import triton.language as tl
# @triton.jit
# def gated_mlp_kernel(
# # Inputs
# x_ptr, # [M, K]
# w1_ptr, # [N, K] -> note: w1 is out_features x in_features
# w2_ptr,
# w3_ptr, # [K_out, N] = [hidden, inner]
# y_ptr, # output [M, K_out]
# # Shapes
# M, # batch * seq_len
# K, # hidden_size (e.g., 4096)
# N, # inner_size (e.g., 11264)
# K_out: tl.constexpr,
# # Tiling
# BLOCK_M: tl.constexpr = 64,
# BLOCK_N: tl.constexpr = 128,
# BLOCK_K: tl.constexpr = 64,
# ):
# pid_m = tl.program_id(0)
# pid_n = tl.program_id(1)
# # Output region covered by this block: [pid_m*BLOCK_M : ..., pid_n*BLOCK_N : ...]
# offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
# offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
# offs_k = tl.arange(0, BLOCK_K)
# # Load one row (or a few rows) of x
# x_ptrs = x_ptr + offs_m[:, None] * K + offs_k[None, :]
# w1_ptrs = w1_ptr + offs_n[:, None] * K + offs_k[None, :]
# w2_ptrs = w2_ptr + offs_n[:, None] * K + offs_k[None, :]
# # Initialize the accumulators
# acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# for k in range(0, K, BLOCK_K):
# # Boundary handling
# k_mask = (offs_k[None, :] < K - k)
# x = tl.load(x_ptrs, mask=k_mask, other=0.0)
# w1 = tl.load(w1_ptrs, mask=k_mask, other=0.0)
# w2 = tl.load(w2_ptrs, mask=k_mask, other=0.0)
# acc1 += tl.dot(x, w1.T)
# acc2 += tl.dot(x, w2.T)
# x_ptrs += BLOCK_K
# w1_ptrs += BLOCK_K
# w2_ptrs += BLOCK_K
# offs_k += BLOCK_K
# # Apply SiLU: x * sigmoid(x)
# z1 = acc1.to(tl.bfloat16)
# z2 = acc2.to(tl.bfloat16)
# sig = tl.sigmoid(z1)
# gated = z1 * sig * z2 # SiLU(z1) * z2
# # Second stage: gated @ w3.T → [M, N] @ [K_out, N].T = [M, K_out]
# # Note: w3 is [K_out, N]; we need gated (M, N) × w3.T (N, K_out)
# offs_k2 = tl.arange(0, BLOCK_K)
# w3_ptrs = w3_ptr + offs_n[:, None] + offs_k2[None, :] * N # w3[k_out, n] → column-major?
# # Safer view: if w3 is [K_out, N] stored row-major, then w3[k, n] = w3_ptr[k*N + n],
# # so loading column n of w3 needs a transposed access pattern.
# # Alternative: for each output column k_out, accumulate gated[:, n] * w3[k_out, n],
# # which would make pid_n correspond to k_out at launch time and require reworking the logic.
# # ⚠️ The design above is flawed! A better approach is to split the work into two kernels:
# # 1. compute gated = SiLU(x@W1) * (x@W2) → [M, N]
# # 2. gated @ W3.T → [M, K_out]
# # Because N=11264 is large, fusing all three stages would spill registers, so only the first
# # two matmuls + activation are fused; the third stage stays on cuBLAS (torch.matmul).
# # A host-side sketch at the end of this file shows this split.
@triton.jit
def gated_proj_kernel(
x_ptr, w1_ptr, w2_ptr, out_ptr,
M, K, N,
stride_xm, stride_xk,
stride_wk, stride_wn, # w is [N, K], so stride_wn = K
stride_om, stride_on,
ACTIVATION: tl.constexpr,
BLOCK_M: tl.constexpr = 64,
BLOCK_N: tl.constexpr = 64,
BLOCK_K: tl.constexpr = 32,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
w1_ptrs = w1_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
w2_ptrs = w2_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
k_mask = offs_k[None, :] < K - k
x = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & k_mask, other=0.0)
w1 = tl.load(w1_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
w2 = tl.load(w2_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
acc1 += tl.dot(x, w1.T)
acc2 += tl.dot(x, w2.T)
x_ptrs += BLOCK_K * stride_xk
w1_ptrs += BLOCK_K * stride_wk
w2_ptrs += BLOCK_K * stride_wk
z1 = acc1.to(tl.bfloat16)
z2 = acc2.to(tl.bfloat16)
if ACTIVATION == "silu":
sig = tl.sigmoid(z1)
out = z1 * sig * z2
elif ACTIVATION == "gelu":
        # Triton has no built-in gelu; approximate it (or fall back to the eager path)
out = z1 * 0.5 * (1 + tl.tanh(0.79788456 * (z1 + 0.044715 * z1 * z1 * z1))) * z2
else:
out = z1 * z2
out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
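# A hedged host-side sketch (assumption: not part of this module) showing how gated_proj_kernel
# above would be driven end to end, with the third projection left to cuBLAS via torch.matmul as
# the design note in the commented-out kernel suggests. Shapes follow the rest of this commit:
# x is [M, K], w1/w2 are [N, K] (nn.Linear layout), w3 is [K_out, N].
def gated_mlp_forward(x, w1, w2, w3, BLOCK_M=64, BLOCK_N=64, BLOCK_K=32):
    import torch
    M, K = x.shape
    N = w1.shape[0]
    gated = torch.empty(M, N, dtype=x.dtype, device=x.device)
    grid = (triton.cdiv(M, BLOCK_M), triton.cdiv(N, BLOCK_N))
    gated_proj_kernel[grid](
        x, w1, w2, gated,
        M, K, N,
        x.stride(0), x.stride(1),
        w1.stride(1), w1.stride(0),      # w is [N, K]: stride over K, stride over N
        gated.stride(0), gated.stride(1),
        ACTIVATION="silu",
        BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N, BLOCK_K=BLOCK_K,
    )
    return gated @ w3.t()                # second stage: [M, N] @ [N, K_out] on cuBLAS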
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
if __name__ == "__main__":
    # Mock config
    config = {
        "hidden_size": 4096,
        "mlp_activation": "silu",
        "model_parallel_size": 1,
        "evo2_style_activations": False,
    }
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
if __name__ == "__main__":
pass
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
import numpy as np
import random
import time
@triton.jit
def gated_proj_kernel(
x_ptr, w1_ptr, w2_ptr, out_ptr,
M, K, N,
stride_xm, stride_xk,
stride_w1k, stride_w1n, # w1 is [K, N]
stride_w2k, stride_w2n, # w2 is [K, N]
stride_om, stride_on,
ACTIVATION: tl.constexpr,
BLOCK_M: tl.constexpr = 64,
BLOCK_N: tl.constexpr = 64,
BLOCK_K: tl.constexpr = 32,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
# x: [M, K]
x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    # w1 and w2: [K, N] (the transposed weights)
    # Note: w1_ptr and w2_ptr already point at the transposed weights
w1_ptrs = w1_ptr + offs_k[:, None] * stride_w1k + offs_n[None, :] * stride_w1n
w2_ptrs = w2_ptr + offs_k[:, None] * stride_w2k + offs_n[None, :] * stride_w2n
acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
        # Load x
        x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K - k)
        x = tl.load(x_ptrs, mask=x_mask, other=0.0)
        # Load w1 and w2
w_mask = (offs_k[:, None] < K - k) & (offs_n[None, :] < N)
w1 = tl.load(w1_ptrs, mask=w_mask, other=0.0)
w2 = tl.load(w2_ptrs, mask=w_mask, other=0.0)
        # Dot products: x @ w1^T and x @ w2^T
        # x: [BLOCK_M, BLOCK_K], w1: [BLOCK_K, BLOCK_N]
        # tl.dot(x, w1) computes x @ w1, but we want x @ w1^T;
        # since w1 here is already the transposed weight [K, N], x @ w1 is exactly x @ w1^T
        acc1 += tl.dot(x, w1)
        acc2 += tl.dot(x, w2)
        # Advance the pointers to the next K block
x_ptrs += BLOCK_K * stride_xk
w1_ptrs += BLOCK_K * stride_w1k
w2_ptrs += BLOCK_K * stride_w2k
    # Apply the activation
    if ACTIVATION == "silu":
        # SiLU(x) = x * sigmoid(x)
sig = tl.sigmoid(acc1)
out = acc1 * sig * acc2 # SiLU(w1*x) * (w2*x)
    # elif ACTIVATION == "gelu":
    #     # GELU approximation
    #     # GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
    #     gelu_approx = 0.5 * acc1 * (1 + tl.tanh(0.79788456 * (acc1 + 0.044715 * acc1 * acc1 * acc1)))
    #     out = gelu_approx * acc2
    # else:
    #     # No activation
    #     out = acc1 * acc2
    # Store the result
out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
tl.store(out_ptrs, out.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
def fused_gated_proj(x, w1, w2, activation="silu"):
"""
x: [M, K] - input
    w1: [N, K] - weight1 (PyTorch Linear weight, shape [out_features, in_features])
    w2: [N, K] - weight2 (PyTorch Linear weight, shape [out_features, in_features])
    Returns: [M, N]
    Computes: activation(w1 @ x^T)^T * (w2 @ x^T)^T
    which is equivalent to: SiLU(x @ w1^T) * (x @ w2^T)
"""
assert x.dtype == torch.bfloat16
assert w1.dtype == torch.bfloat16 and w2.dtype == torch.bfloat16
    M, K = x.shape  # e.g. M=1, K=4096
    N, K2 = w1.shape  # N=11264, K2=4096
    assert K == K2, f"Dimension mismatch: x K={K}, w1 K={K2}"
    assert w2.shape == (N, K), f"w2 shape mismatch: expected {(N, K)}, got {w2.shape}"
    # Transpose the weights to [K, N] up front
    w1_t = w1.t().contiguous() # [K, N]
    w2_t = w2.t().contiguous() # [K, N]
    out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N'])
    )
    gated_proj_kernel[grid](
        x, w1_t, w2_t, out,  # pass the transposed weights
        M, K, N,
        x.stride(0), x.stride(1),
        w1_t.stride(0), w1_t.stride(1),  # strides of the [K, N] weight
w2_t.stride(0), w2_t.stride(1),
out.stride(0), out.stride(1),
ACTIVATION=activation,
BLOCK_M=64,
BLOCK_N=64,
BLOCK_K=32,
)
return out
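# A hedged usage note (the shapes below are assumptions): for bf16 tensors the call above mirrors
#   F.silu(F.linear(x, w1)) * F.linear(x, w2)
# e.g.
#   x  = torch.randn(4, 4096, dtype=torch.bfloat16, device="cuda")
#   w1 = torch.randn(11264, 4096, dtype=torch.bfloat16, device="cuda")
#   w2 = torch.randn(11264, 4096, dtype=torch.bfloat16, device="cuda")
#   out = fused_gated_proj(x, w1, w2)   # -> [4, 11264]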
class ParallelGatedMLP(nn.Module):
def __init__(self):
super().__init__()
self.act = F.silu
self.act_type = "silu"
self.l1 = nn.Linear(
in_features=4096,
out_features=11264,
bias=False,
)
self.l2 = nn.Linear(
in_features=4096,
out_features=11264,
bias=False,
)
self.l3 = nn.Linear(
in_features=11264,
out_features=4096,
bias=False,
)
def forward_org(self, z):
"""原始实现"""
shape = z.shape
z_flat = z.view(-1, shape[-1]) # [M, K]
# PyTorch: F.linear(x, weight) = x @ weight^T
# z1 = F.linear(z_flat, self.l1.weight) # [M, N]
# z2 = F.linear(z_flat, self.l2.weight) # [M, N]
z1, z2 = self.l1(z_flat), self.l2(z_flat)
gated = self.act(z1) * z2
return gated
def forward_opt(self, z):
"""Triton优化实现"""
shape = z.shape
z_flat = z.view(-1, shape[-1]) # [M, K]
# Triton 路径
gated = fused_gated_proj(
z_flat,
self.l1.weight, # [N, K]
self.l2.weight, # [N, K]
activation=self.act_type
)
return gated
if __name__ == "__main__":
seed = 1111
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
    # Create the model
    model = ParallelGatedMLP()
    model = model.to(dtype=torch.bfloat16, device="cuda:0")
    # Test different batch sizes
for batch_size in [1]:
print(f"\n{'='*50}")
print(f"Testing with batch_size={batch_size}")
print('='*50)
x = torch.randn(batch_size, 1, 4096, dtype=torch.bfloat16, device="cuda:0")
with torch.no_grad():
            # Warm-up
            for _ in range(3):
                _ = model.forward_org(x)
                _ = model.forward_opt(x)
            torch.cuda.synchronize()
            t0 = time.time()
            # Original version
            for i in range(10):
                result_org = model.forward_org(x)
            torch.cuda.synchronize()
            t1 = time.time()
            print(f"Time taken for forward_org: {t1 - t0:.5f} seconds")
            # Optimized version
            for i in range(10):
                result_opt = model.forward_opt(x)
            torch.cuda.synchronize()
            print(f"Time taken for forward_opt: {time.time() - t1:.5f} seconds")
            # # Verify results
            # print(f"ORG shape: {result_org.shape}")
            # print(f"OPT shape: {result_opt.shape}")
            # # Compute the difference
            # diff = torch.abs(result_org - result_opt)
            # print(f"Max diff: {diff.max().item():.6f}")
            # print(f"Mean diff: {diff.mean().item():.6f}")
            # print(f"Min diff: {diff.min().item():.6f}")
            # # Relative error
            # rel_error = diff / (torch.abs(result_org) + 1e-8)
            # print(f"Max relative error: {rel_error.max().item():.6f}")
            # print(f"Mean relative error: {rel_error.mean().item():.6f}")
            # # Check the first few values
            # print("\nFirst 10 values comparison:")
            # print(f"ORG: {result_org[0, :10].float().cpu().numpy()}")
            # print(f"OPT: {result_opt[0, :10].float().cpu().numpy()}")
            # print(f"Diff: {diff[0, :10].float().cpu().numpy()}")
            # # Check whether the results match
# if torch.allclose(result_org, result_opt, rtol=1e-2, atol=1e-3):
# print("✓ Results match within tolerance!")
# else:
# print("✗ Results do not match!")
            # # Additional verification: check mathematical equivalence
            # print(f"\n{'='*50}")
            # print("Additional mathematical verification")
            # print('='*50)
            # # Verify with small matrices
            # test_x = torch.randn(2, 16, dtype=torch.bfloat16, device="cuda:0")
            # test_w1 = torch.randn(32, 16, dtype=torch.bfloat16, device="cuda:0")
            # test_w2 = torch.randn(32, 16, dtype=torch.bfloat16, device="cuda:0")
            # # PyTorch computation
            # z1_pt = F.linear(test_x, test_w1) # x @ w1^T
            # z2_pt = F.linear(test_x, test_w2) # x @ w2^T
            # result_pt = F.silu(z1_pt) * z2_pt
            # # Triton computation
# result_triton = fused_gated_proj(test_x, test_w1, test_w2, activation="silu")
# diff_test = torch.abs(result_pt - result_triton)
# print(f"Test max diff: {diff_test.max().item():.6f}")
# print(f"Test mean diff: {diff_test.mean().item():.6f}")
# if torch.allclose(result_pt, result_triton, rtol=1e-2, atol=1e-3):
# print("✓ Test passed: Triton implementation matches PyTorch!")
# else:
# print("✗ Test failed: Triton implementation doesn't match PyTorch!")
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
import time
import torch
import numpy as np
import random
@triton.jit
def gated_proj_kernel(
x_ptr, w1_ptr, w2_ptr, out_ptr,
M, K, N,
stride_xm, stride_xk,
stride_wk, stride_wn, # w is [N, K], so stride_wn = K
stride_om, stride_on,
ACTIVATION: tl.constexpr,
BLOCK_M: tl.constexpr = 64,
BLOCK_N: tl.constexpr = 64,
BLOCK_K: tl.constexpr = 32,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
w1_ptrs = w1_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
w2_ptrs = w2_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
k_mask = offs_k[None, :] < K - k
x = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & k_mask, other=0.0)
w1 = tl.load(w1_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
w2 = tl.load(w2_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
acc1 += tl.dot(x, w1.T)
acc2 += tl.dot(x, w2.T)
x_ptrs += BLOCK_K * stride_xk
w1_ptrs += BLOCK_K * stride_wk
w2_ptrs += BLOCK_K * stride_wk
z1 = acc1.to(tl.float32)
z2 = acc2.to(tl.float32)
if ACTIVATION == "silu":
sig = tl.sigmoid(z1)
out = z1 * sig * z2
elif ACTIVATION == "gelu":
        # Triton has no built-in gelu; approximate it (or fall back to the eager path)
# out = z1 * 0.5 * (1 + tl.tanh(0.79788456 * (z1 + 0.044715 * z1 * z1 * z1))) * z2
sig = tl.sigmoid(z1)
out = z1 * sig * z2
else:
out = z1 * z2
out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
tl.store(out_ptrs, out.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
def fused_gated_proj(x, w1, w2, activation="silu"):
assert x.dtype == torch.bfloat16
assert w1.dtype == torch.bfloat16 and w2.dtype == torch.bfloat16
    M, K = x.shape  # e.g. 1, 4096
    N, _ = w1.shape  # 11264
assert w2.shape == (N, K)
out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
grid = lambda META: (
triton.cdiv(M, META['BLOCK_M']),
triton.cdiv(N, META['BLOCK_N'])
)
gated_proj_kernel[grid](
x, w1, w2, out,
M, K, N,
x.stride(0), x.stride(1),
w1.stride(1), w1.stride(0),
out.stride(0), out.stride(1),
ACTIVATION=activation,
BLOCK_M=64,
BLOCK_N=64,
BLOCK_K=32,
)
return out
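# Design note on the wrapper above: unlike the fused_gated_proj variant elsewhere in this commit
# that materializes w.t().contiguous(), this one feeds the [N, K] nn.Linear weights to the kernel
# as-is and simply swaps the two weight strides, so no per-call copy of the 11264 x 4096 weights
# is needed.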
class ParallelGatedMLP(nn.Module):
def __init__(self):
super().__init__()
self.act = F.silu
self.act_type = "silu"
self.l1 = nn.Linear(
in_features=4096,
out_features=11264,
bias=False,
)
self.l2 = nn.Linear(
in_features=4096,
out_features=11264,
bias=False,
)
self.l3 = nn.Linear(
in_features=11264,
out_features=4096,
bias=False,
)
        # Make sure the weights are contiguous (nn.Linear weights normally are; this is just to be safe)
self.l1.weight = torch.nn.Parameter(self.l1.weight.contiguous())
self.l2.weight = torch.nn.Parameter(self.l2.weight.contiguous())
self.l3.weight = torch.nn.Parameter(self.l3.weight.contiguous())
def forward(self, z):
# z: [B, S, D] → flatten to [M, D]
shape = z.shape
z_flat = z.view(-1, int(shape[-1])) # [M, D]
        # Triton path
gated = fused_gated_proj(
z_flat,
self.l1.weight, # [inner, hidden]
self.l2.weight,
activation=self.act_type
)
# y_flat = self.l3(gated) # [M, D]
# y = y_flat.view(*shape)
return gated
def forward_org(self, z):
shape = z.shape
z_flat = z.view(-1, shape[-1])
        # Native path, used for GELU or for debugging
z1, z2 = self.l1(z_flat), self.l2(z_flat)
gated = self.act(z1) * z2
return gated
def forward_opt(self, z):
# z: [B, S, D] → flatten to [M, D]
shape = z.shape
z_flat = z.view(-1, int(shape[-1])) # [M, D]
        # Triton path
gated = fused_gated_proj(
z_flat,
self.l1.weight, # [inner, hidden]
self.l2.weight,
activation=self.act_type
)
return gated
if __name__ == "__main__":
seed = 1111
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if using multi-GPU
np.random.seed(seed)
random.seed(seed)
    # Optional: trade performance for reproducibility (some CUDA ops are nondeterministic)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Create the model
    model = ParallelGatedMLP()
    # Cast the model to bfloat16
    model = model.to(dtype=torch.bfloat16, device="cuda:0")
    # Create the input tensor (batch=1, seq_len=1, hidden=4096)
    device = "cuda:0"  # or "cuda" on any GPU with bf16 support (e.g. A100, H100)
x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)
with torch.no_grad():
result_org = model.forward_org(x)
print(f"ORG: {result_org[0, :20]}")
result_opt = model.forward_opt(x)
print(f"OPT: {result_opt[0, :20]}")
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
# Stand-in for grab_first_if_tuple (return the first element if the value is a tuple, otherwise return it unchanged)
def grab_first_if_tuple(x):
return x[0] if isinstance(x, tuple) else x
class ParallelGatedMLP(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.layer_idx = layer_idx
multiple_of = config.get("inner_size_multiple_of", 64)
self.act_type = config.get("mlp_activation", "gelu")
if self.act_type == "gelu":
self.act = F.gelu
elif self.act_type == "silu":
self.act = F.silu
else:
raise NotImplementedError
if self.layer_idx > 0 and config.get("evo2_style_activations", False):
self.act = nn.Identity()
inner_size = 11264
self.l1 = nn.Linear(
in_features=config.get("hidden_size", 4096),
out_features=inner_size,
bias=False,
)
self.l2 = nn.Linear(
in_features=config.get("hidden_size", 4096),
out_features=inner_size,
bias=False,
)
self.l3 = nn.Linear(
in_features=inner_size,
out_features=config.get("hidden_size", 4096),
bias=False,
)
        # Make sure the weights are contiguous (nn.Linear weights normally are; this is just to be safe)
self.l1.weight = torch.nn.Parameter(self.l1.weight.contiguous())
self.l2.weight = torch.nn.Parameter(self.l2.weight.contiguous())
self.l3.weight = torch.nn.Parameter(self.l3.weight.contiguous())
def forward(self, z):
z1, z2 = self.l1(z), self.l2(z)
return z1, z2
# === Example usage ===
if __name__ == "__main__":
    # Mock config
    config = {
        "hidden_size": 4096,
        "mlp_activation": "silu",
        "model_parallel_size": 1,
        "evo2_style_activations": False,
    }
    layer_idx = 0
    # Create the model
    model = ParallelGatedMLP(config, layer_idx)
    # Cast the model to bfloat16
    model = model.to(dtype=torch.bfloat16, device="cuda:0")
    # Create the input tensor (batch=1, seq_len=1, hidden=4096)
    device = "cuda:0"  # or "cuda" on any GPU with bf16 support (e.g. A100, H100)
    x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)
    # Inference
with torch.no_grad():
for i in range(10):
output = model(x)
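        # A hedged timing sketch (torch.cuda.synchronize and time are assumptions, not part of
        # the original script): kernel launches are asynchronous, so the loop must be bracketed
        # by synchronizations before wall-clock timestamps mean anything.
        import time
        torch.cuda.synchronize()
        t0 = time.time()
        for i in range(10):
            output = model(x)
        torch.cuda.synchronize()
        print(f"10 forward passes: {time.time() - t0:.5f} s")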
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
import numpy as np
import random
import time
@triton.jit
def matmul_kernel(
x_ptr, w_ptr, out_ptr,
M, K, N,
stride_xm, stride_xk,
    stride_wk, stride_wn, # w is [K, N] (already transposed)
stride_om, stride_on,
BLOCK_M: tl.constexpr = 64,
BLOCK_N: tl.constexpr = 64,
BLOCK_K: tl.constexpr = 32,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
w_ptrs = w_ptr + offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K - k)
w_mask = (offs_k[:, None] < K - k) & (offs_n[None, :] < N)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
w = tl.load(w_ptrs, mask=w_mask, other=0.0)
acc += tl.dot(x, w)
x_ptrs += BLOCK_K * stride_xk
w_ptrs += BLOCK_K * stride_wk
    # Convert the output to bfloat16
out = acc.to(tl.bfloat16)
out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
def triton_matmul(x, weight):
"""
Compute y = x @ weight.T using Triton.
x: [M, K], dtype=bfloat16
    weight: [N, K], dtype=bfloat16 (PyTorch Linear weight, shape [out_features, in_features])
Returns: y: [M, N], dtype=bfloat16
"""
assert x.dtype == torch.bfloat16
assert weight.dtype == torch.bfloat16
assert x.device == weight.device
assert x.is_contiguous()
M, K = x.shape
N, K2 = weight.shape
assert K == K2, f"K mismatch: {K} != {K2}"
    # Transpose the weight to [K, N] up front so the Triton kernel can consume it directly;
    # weight is [N, K], and we need weight.T = [K, N]
    w_t = weight.t().contiguous() # [K, N]
    out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N'])
    )
    # Note: the transposed weight w_t of shape [K, N] is what gets passed here
matmul_kernel[grid](
x, w_t, out,
M, K, N,
x.stride(0), x.stride(1),
w_t.stride(0), w_t.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=64,
BLOCK_N=64,
BLOCK_K=32,
)
return out
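# Design note on triton_matmul: weight.t().contiguous() copies the whole [K, N] weight on every
# call; for repeated inference the transposed weight could be computed once and cached, or the
# kernel could index the original [N, K] layout through its strides (as the gated-projection
# wrapper elsewhere in this commit does).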
class ParallelGatedMLP(nn.Module):
def __init__(self):
super().__init__()
self.act = F.silu
self.act_type = "silu"
self.l1 = nn.Linear(
in_features=4096,
out_features=11264,
bias=False,
)
self.l2 = nn.Linear(
in_features=4096,
out_features=11264,
bias=False,
)
self.l3 = nn.Linear(
in_features=11264,
out_features=4096,
bias=False,
)
def forward_org(self, z):
shape = z.shape
z_flat = z.view(-1, shape[-1]) # [M, K]
y = F.linear(z_flat, self.l1.weight, bias=None) # [M, N]
return y
def forward_org_triton(self, z):
shape = z.shape
z_flat = z.view(-1, shape[-1]) # [M, K]
y = triton_matmul(z_flat, self.l1.weight) # [M, N]
return y
if __name__ == "__main__":
seed = 1111
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
model = ParallelGatedMLP()
model = model.to(dtype=torch.bfloat16, device="cuda:0")
device = "cuda:0"
x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)
    # Correctness test
    with torch.no_grad():
        result_org = model.forward_org(x)
        result_opt = model.forward_org_triton(x)
        print(f"ORG shape: {result_org.shape}")
        print(f"OPT shape: {result_opt.shape}")
        # Print the first 20 elements for comparison
        print(f"\nORG first 20: {result_org[0, :20]}")
        print(f"OPT first 20: {result_opt[0, :20]}")
        # Compute the difference
        diff = torch.abs(result_org - result_opt)
        print(f"\nMax diff: {diff.max().item()}")
        print(f"Mean diff: {diff.mean().item()}")
        # Relative error check
        rel_error = diff / (torch.abs(result_org) + 1e-8)
        print(f"Max relative error: {rel_error.max().item()}")
        # Check agreement within a reasonable tolerance (floating-point reordering differences)
if torch.allclose(result_org, result_opt, rtol=1e-2, atol=1e-3):
print("\n✓ Results match within tolerance!")
else:
print("\n✗ Results do not match!")
\ No newline at end of file
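# ROCBLAS_LAYER is rocBLAS's logging bitmask (assumption: 1 enables trace logging, 2 enables
# bench logging, so 3 enables both); it is set here to capture the GEMM calls issued by the
# script below.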
export ROCBLAS_LAYER=3
python trition_opt.py
\ No newline at end of file
rm -rf *.db
rm -rf *.csv
rm -rf *.txt
rm -rf *.json
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
import numpy as np
import random
import time
@triton.jit
def matmul_kernel(
x_ptr, w_ptr, out_ptr,
M, K, N,
stride_xm, stride_xk,
    stride_wk, stride_wn, # w is [K, N] (already transposed)
stride_om, stride_on,
BLOCK_M: tl.constexpr = 64,
BLOCK_N: tl.constexpr = 64,
BLOCK_K: tl.constexpr = 32,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
w_ptrs = w_ptr + offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn
acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K - k)
w_mask = (offs_k[:, None] < K - k) & (offs_n[None, :] < N)
x = tl.load(x_ptrs, mask=x_mask, other=0.0)
w = tl.load(w_ptrs, mask=w_mask, other=0.0)
acc += tl.dot(x, w)
x_ptrs += BLOCK_K * stride_xk
w_ptrs += BLOCK_K * stride_wk
    # Convert the output to bfloat16
out = acc.to(tl.bfloat16)
out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
def triton_matmul(x, weight):
"""
Compute y = x @ weight.T using Triton.
x: [M, K], dtype=bfloat16
    weight: [N, K], dtype=bfloat16 (PyTorch Linear weight, shape [out_features, in_features])
Returns: y: [M, N], dtype=bfloat16
"""
assert x.dtype == torch.bfloat16
assert weight.dtype == torch.bfloat16
assert x.device == weight.device
assert x.is_contiguous()
M, K = x.shape
N, K2 = weight.shape
assert K == K2, f"K mismatch: {K} != {K2}"
    # Transpose the weight to [K, N] up front so the Triton kernel can consume it directly;
    # weight is [N, K], and we need weight.T = [K, N]
    w_t = weight.t().contiguous() # [K, N]
    out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N'])
    )
    # Note: the transposed weight w_t of shape [K, N] is what gets passed here
matmul_kernel[grid](
x, w_t, out,
M, K, N,
x.stride(0), x.stride(1),
w_t.stride(0), w_t.stride(1),
out.stride(0), out.stride(1),
BLOCK_M=64,
BLOCK_N=64,
BLOCK_K=32,
)
return out
class ParallelGatedMLP(nn.Module):
def __init__(self):
super().__init__()
self.act = F.silu
self.act_type = "silu"
self.l1 = nn.Linear(
in_features=4096,
out_features=11264,
bias=False,
)
self.l2 = nn.Linear(
in_features=4096,
out_features=11264,
bias=False,
)
self.l3 = nn.Linear(
in_features=11264,
out_features=4096,
bias=False,
)
def forward_org(self, z):
shape = z.shape
z_flat = z.view(-1, shape[-1]) # [M, K]
        # bfloat16 precision
# self.l1 = nn.Linear(
# in_features=4096,
# out_features=11264,
# bias=False,
# )
# z_flat.shape 1,4096
y = F.linear(z_flat, self.l1.weight, bias=None) # [M, N]
return y
def forward_org_triton(self, z):
shape = z.shape
z_flat = z.view(-1, shape[-1]) # [M, K]
y = triton_matmul(z_flat, self.l1.weight) # [M, N]
return y
if __name__ == "__main__":
seed = 1111
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
model = ParallelGatedMLP()
model = model.to(dtype=torch.bfloat16, device="cuda:0")
device = "cuda:0"
x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)
    # Correctness test
    with torch.no_grad():
        result_org = model.forward_org(x)
        result_opt = model.forward_org_triton(x)
        print(f"ORG shape: {result_org.shape}")
        print(f"OPT shape: {result_opt.shape}")
        # Print the first 20 elements for comparison
        print(f"\nORG first 20: {result_org[0, :20]}")
        print(f"OPT first 20: {result_opt[0, :20]}")
        # Compute the difference
        diff = torch.abs(result_org - result_opt)
        print(f"\nMax diff: {diff.max().item()}")
        print(f"Mean diff: {diff.mean().item()}")
        # Relative error check
        rel_error = diff / (torch.abs(result_org) + 1e-8)
        print(f"Max relative error: {rel_error.max().item()}")
        # Check agreement within a reasonable tolerance (floating-point reordering differences)
if torch.allclose(result_org, result_opt, rtol=1e-2, atol=1e-3):
print("\n✓ Results match within tolerance!")
else:
print("\n✗ Results do not match!")
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
import numpy as np
import random
import time
@triton.jit
def gated_proj_kernel(
x_ptr, w1_ptr, w2_ptr, out_ptr,
M, K, N,
stride_xm, stride_xk,
stride_w1k, stride_w1n, # w1 is [K, N]
stride_w2k, stride_w2n, # w2 is [K, N]
stride_om, stride_on,
ACTIVATION: tl.constexpr,
BLOCK_M: tl.constexpr = 64,
BLOCK_N: tl.constexpr = 64,
BLOCK_K: tl.constexpr = 32,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
# x: [M, K]
x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    # w1 and w2: [K, N] (the transposed weights)
    # Note: w1_ptr and w2_ptr already point at the transposed weights
w1_ptrs = w1_ptr + offs_k[:, None] * stride_w1k + offs_n[None, :] * stride_w1n
w2_ptrs = w2_ptr + offs_k[:, None] * stride_w2k + offs_n[None, :] * stride_w2n
acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
        # Load x
        x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K - k)
        x = tl.load(x_ptrs, mask=x_mask, other=0.0)
        # Load w1 and w2
w_mask = (offs_k[:, None] < K - k) & (offs_n[None, :] < N)
w1 = tl.load(w1_ptrs, mask=w_mask, other=0.0)
w2 = tl.load(w2_ptrs, mask=w_mask, other=0.0)
        # Dot products: x @ w1^T and x @ w2^T
        # x: [BLOCK_M, BLOCK_K], w1: [BLOCK_K, BLOCK_N]
        # tl.dot(x, w1) computes x @ w1, but we want x @ w1^T;
        # since w1 here is already the transposed weight [K, N], x @ w1 is exactly x @ w1^T
        acc1 += tl.dot(x, w1)
        acc2 += tl.dot(x, w2)
        # Advance the pointers to the next K block
x_ptrs += BLOCK_K * stride_xk
w1_ptrs += BLOCK_K * stride_w1k
w2_ptrs += BLOCK_K * stride_w2k
    # Apply the activation
    if ACTIVATION == "silu":
        # SiLU(x) = x * sigmoid(x)
sig = tl.sigmoid(acc1)
out = acc1 * sig * acc2 # SiLU(w1*x) * (w2*x)
    # elif ACTIVATION == "gelu":
    #     # GELU approximation
    #     # GELU(x) ≈ 0.5 * x * (1 + tanh(sqrt(2/π) * (x + 0.044715 * x^3)))
    #     gelu_approx = 0.5 * acc1 * (1 + tl.tanh(0.79788456 * (acc1 + 0.044715 * acc1 * acc1 * acc1)))
    #     out = gelu_approx * acc2
    # else:
    #     # No activation
    #     out = acc1 * acc2
    # Store the result
out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
tl.store(out_ptrs, out.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
def fused_gated_proj(x, w1, w2, activation="silu"):
"""
x: [M, K] - input
    w1: [N, K] - weight1 (PyTorch Linear weight, shape [out_features, in_features])
    w2: [N, K] - weight2 (PyTorch Linear weight, shape [out_features, in_features])
    Returns: [M, N]
    Computes: activation(w1 @ x^T)^T * (w2 @ x^T)^T
    which is equivalent to: SiLU(x @ w1^T) * (x @ w2^T)
"""
assert x.dtype == torch.bfloat16
assert w1.dtype == torch.bfloat16 and w2.dtype == torch.bfloat16
    M, K = x.shape  # e.g. M=1, K=4096
    N, K2 = w1.shape  # N=11264, K2=4096
    assert K == K2, f"Dimension mismatch: x K={K}, w1 K={K2}"
    assert w2.shape == (N, K), f"w2 shape mismatch: expected {(N, K)}, got {w2.shape}"
    # Transpose the weights to [K, N] up front
    w1_t = w1.t().contiguous() # [K, N]
    w2_t = w2.t().contiguous() # [K, N]
    out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N'])
    )
    gated_proj_kernel[grid](
        x, w1_t, w2_t, out,  # pass the transposed weights
        M, K, N,
        x.stride(0), x.stride(1),
        w1_t.stride(0), w1_t.stride(1),  # strides of the [K, N] weight
w2_t.stride(0), w2_t.stride(1),
out.stride(0), out.stride(1),
ACTIVATION=activation,
BLOCK_M=64,
BLOCK_N=64,
BLOCK_K=32,
)
return out
class ParallelGatedMLP(nn.Module):
def __init__(self):
super().__init__()
self.act = F.silu
self.act_type = "silu"
self.l1 = nn.Linear(
in_features=4096,
out_features=11264,
bias=False,
)
self.l2 = nn.Linear(
in_features=4096,
out_features=11264,
bias=False,
)
self.l3 = nn.Linear(
in_features=11264,
out_features=4096,
bias=False,
)
def forward_org(self, z):
"""原始实现"""
shape = z.shape
z_flat = z.view(-1, shape[-1]) # [M, K]
# PyTorch: F.linear(x, weight) = x @ weight^T
# z1 = F.linear(z_flat, self.l1.weight) # [M, N]
# z2 = F.linear(z_flat, self.l2.weight) # [M, N]
z1, z2 = self.l1(z_flat), self.l2(z_flat)
gated = self.act(z1) * z2
return gated
def forward_opt(self, z):
"""Triton优化实现"""
shape = z.shape
z_flat = z.view(-1, shape[-1]) # [M, K]
# Triton 路径
gated = fused_gated_proj(
z_flat,
self.l1.weight, # [N, K]
self.l2.weight, # [N, K]
activation=self.act_type
)
return gated
if __name__ == "__main__":
seed = 1111
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
np.random.seed(seed)
random.seed(seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
    # Create the model
    model = ParallelGatedMLP()
    model = model.to(dtype=torch.bfloat16, device="cuda:0")
    # Test different batch sizes
for batch_size in [1]:
print(f"\n{'='*50}")
print(f"Testing with batch_size={batch_size}")
print('='*50)
x = torch.randn(batch_size, 1, 4096, dtype=torch.bfloat16, device="cuda:0")
with torch.no_grad():
            # Warm-up
            for _ in range(3):
                _ = model.forward_org(x)
                _ = model.forward_opt(x)
            torch.cuda.synchronize()
            t0 = time.time()
            # Original version
            result_org = model.forward_org(x)
            torch.cuda.synchronize()
            t1 = time.time()
            print(f"Time taken for forward_org: {t1 - t0:.4f} seconds")
            # Optimized version
            result_opt = model.forward_opt(x)
            torch.cuda.synchronize()
            print(f"Time taken for forward_opt: {time.time() - t1:.4f} seconds")
            # Verify results
            print(f"ORG shape: {result_org.shape}")
            print(f"OPT shape: {result_opt.shape}")
            # Compute the difference
            diff = torch.abs(result_org - result_opt)
            print(f"Max diff: {diff.max().item():.6f}")
            print(f"Mean diff: {diff.mean().item():.6f}")
            print(f"Min diff: {diff.min().item():.6f}")
            # Relative error
            rel_error = diff / (torch.abs(result_org) + 1e-8)
            print(f"Max relative error: {rel_error.max().item():.6f}")
            print(f"Mean relative error: {rel_error.mean().item():.6f}")
            # Check the first few values
            print("\nFirst 10 values comparison:")
            print(f"ORG: {result_org[0, :10].float().cpu().numpy()}")
            print(f"OPT: {result_opt[0, :10].float().cpu().numpy()}")
            print(f"Diff: {diff[0, :10].float().cpu().numpy()}")
            # Check whether the results match
if torch.allclose(result_org, result_opt, rtol=1e-2, atol=1e-3):
print("✓ Results match within tolerance!")
else:
print("✗ Results do not match!")
            # # Additional verification: check mathematical equivalence
            # print(f"\n{'='*50}")
            # print("Additional mathematical verification")
            # print('='*50)
            # # Verify with small matrices
            # test_x = torch.randn(2, 16, dtype=torch.bfloat16, device="cuda:0")
            # test_w1 = torch.randn(32, 16, dtype=torch.bfloat16, device="cuda:0")
            # test_w2 = torch.randn(32, 16, dtype=torch.bfloat16, device="cuda:0")
            # # PyTorch computation
            # z1_pt = F.linear(test_x, test_w1) # x @ w1^T
            # z2_pt = F.linear(test_x, test_w2) # x @ w2^T
            # result_pt = F.silu(z1_pt) * z2_pt
            # # Triton computation
# result_triton = fused_gated_proj(test_x, test_w1, test_w2, activation="silu")
# diff_test = torch.abs(result_pt - result_triton)
# print(f"Test max diff: {diff_test.max().item():.6f}")
# print(f"Test mean diff: {diff_test.mean().item():.6f}")
# if torch.allclose(result_pt, result_triton, rtol=1e-2, atol=1e-3):
# print("✓ Test passed: Triton implementation matches PyTorch!")
# else:
# print("✗ Test failed: Triton implementation doesn't match PyTorch!")
\ No newline at end of file
import torch
import torch.nn as nn
import torch.nn.functional as F
import triton
import triton.language as tl
import time
import torch
import numpy as np
import random
seed = 1111
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed) # if using multi-GPU
np.random.seed(seed)
random.seed(seed)
# Optional: trade performance for reproducibility (some CUDA ops are nondeterministic)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
@triton.jit
def gated_proj_kernel(
x_ptr, w1_ptr, w2_ptr, out_ptr,
M, K, N,
stride_xm, stride_xk,
stride_wk, stride_wn, # w is [N, K], so stride_wn = K
stride_om, stride_on,
ACTIVATION: tl.constexpr,
BLOCK_M: tl.constexpr = 64,
BLOCK_N: tl.constexpr = 64,
BLOCK_K: tl.constexpr = 32,
):
pid_m = tl.program_id(0)
pid_n = tl.program_id(1)
offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
offs_k = tl.arange(0, BLOCK_K)
x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
w1_ptrs = w1_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
w2_ptrs = w2_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
for k in range(0, K, BLOCK_K):
k_mask = offs_k[None, :] < K - k
x = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & k_mask, other=0.0)
w1 = tl.load(w1_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
w2 = tl.load(w2_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
acc1 += tl.dot(x, w1.T)
acc2 += tl.dot(x, w2.T)
x_ptrs += BLOCK_K * stride_xk
w1_ptrs += BLOCK_K * stride_wk
w2_ptrs += BLOCK_K * stride_wk
z1 = acc1.to(tl.float32)
z2 = acc2.to(tl.float32)
if ACTIVATION == "silu":
sig = tl.sigmoid(z1)
out = z1 * sig * z2
elif ACTIVATION == "gelu":
        # Triton has no built-in gelu; approximate it (or fall back to the eager path)
# out = z1 * 0.5 * (1 + tl.tanh(0.79788456 * (z1 + 0.044715 * z1 * z1 * z1))) * z2
sig = tl.sigmoid(z1)
out = z1 * sig * z2
else:
out = z1 * z2
out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
tl.store(out_ptrs, out.to(tl.bfloat16), mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
def fused_gated_proj(x, w1, w2, activation="silu"):
assert x.dtype == torch.bfloat16
assert w1.dtype == torch.bfloat16 and w2.dtype == torch.bfloat16
M, K = x.shape
N, _ = w1.shape
assert w2.shape == (N, K)
out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
grid = lambda META: (
triton.cdiv(M, META['BLOCK_M']),
triton.cdiv(N, META['BLOCK_N'])
)
gated_proj_kernel[grid](
x, w1, w2, out,
M, K, N,
x.stride(0), x.stride(1),
w1.stride(1), w1.stride(0),
out.stride(0), out.stride(1),
ACTIVATION=activation,
BLOCK_M=64,
BLOCK_N=64,
BLOCK_K=32,
)
return out
# Stand-in for grab_first_if_tuple (return the first element if the value is a tuple, otherwise return it unchanged)
def grab_first_if_tuple(x):
return x[0] if isinstance(x, tuple) else x
class ParallelGatedMLP(nn.Module):
def __init__(self, config, layer_idx):
super().__init__()
self.layer_idx = layer_idx
multiple_of = config.get("inner_size_multiple_of", 64)
self.act_type = config.get("mlp_activation", "gelu")
if self.act_type == "gelu":
self.act = F.gelu
elif self.act_type == "silu":
self.act = F.silu
else:
raise NotImplementedError
self.act_type = "silu"
if self.layer_idx > 0 and config.get("evo2_style_activations", False):
self.act = nn.Identity()
inner_size = 11264
self.l1 = nn.Linear(
in_features=config.get("hidden_size", 4096),
out_features=inner_size,
bias=False,
)
self.l2 = nn.Linear(
in_features=config.get("hidden_size", 4096),
out_features=inner_size,
bias=False,
)
self.l3 = nn.Linear(
in_features=inner_size,
out_features=config.get("hidden_size", 4096),
bias=False,
)
        # Make sure the weights are contiguous (nn.Linear weights normally are; this is just to be safe)
self.l1.weight = torch.nn.Parameter(self.l1.weight.contiguous())
self.l2.weight = torch.nn.Parameter(self.l2.weight.contiguous())
self.l3.weight = torch.nn.Parameter(self.l3.weight.contiguous())
def forward(self, z):
# z: [B, S, D] → flatten to [M, D]
shape = z.shape
z_flat = z.view(-1, int(shape[-1])) # [M, D]
        # Triton path
gated = fused_gated_proj(
z_flat,
self.l1.weight, # [inner, hidden]
self.l2.weight,
activation=self.act_type
)
# y_flat = self.l3(gated) # [M, D]
# y = y_flat.view(*shape)
return gated
def forward_org(self, z):
shape = z.shape
z_flat = z.view(-1, shape[-1])
        # Native path, used for GELU or for debugging
z1, z2 = self.l1(z_flat), self.l2(z_flat)
gated = self.act(z1) * z2
return gated
def forward_opt(self, z):
# z: [B, S, D] → flatten to [M, D]
shape = z.shape
z_flat = z.view(-1, int(shape[-1])) # [M, D]
        # Triton path
gated = fused_gated_proj(
z_flat,
self.l1.weight, # [inner, hidden]
self.l2.weight,
activation=self.act_type
)
return gated
if __name__ == "__main__":
    # Mock config
    config = {
        "hidden_size": 4096,
        "mlp_activation": "silu",
        "model_parallel_size": 1,
        "evo2_style_activations": False,
    }
    layer_idx = 0
    # Create the model
    model = ParallelGatedMLP(config, layer_idx)
    # Cast the model to bfloat16
    model = model.to(dtype=torch.bfloat16, device="cuda:0")
    # Create the input tensor (batch=1, seq_len=1, hidden=4096)
    device = "cuda:0"  # or "cuda" on any GPU with bf16 support (e.g. A100, H100)
    x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)
with torch.no_grad():
result_org = model.forward_org(x)
print(f"ORG: {result_org[0, :20]}")
result_opt = model.forward_opt(x)
print(f"OPT: {result_opt[0, :20]}")
    # Inference
# with torch.no_grad():
# for i in range(1000):
# output = model(x)
\ No newline at end of file