Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
wangkx1
ai-compiler
Commits
c601083d
Commit
c601083d
authored
Apr 14, 2026
by
liuys
🏸
Browse files
update triton
parent
2add9fa3
Changes
17
Hide whitespace changes
Inline
Side-by-side
Showing
17 changed files
with
2495 additions
and
0 deletions
+2495
-0
clean_hipprof.sh
clean_hipprof.sh
+4
-0
gemm_cutss/demo.cpp
gemm_cutss/demo.cpp
+576
-0
gemm_cutss/demo.cu
gemm_cutss/demo.cu
+274
-0
triton/oligoformer-opt/opt/__pycache__/org.cpython-310.pyc
triton/oligoformer-opt/opt/__pycache__/org.cpython-310.pyc
+0
-0
triton/oligoformer-opt/opt/demo-opt.py
triton/oligoformer-opt/opt/demo-opt.py
+77
-0
triton/oligoformer-opt/opt/demo.py
triton/oligoformer-opt/opt/demo.py
+133
-0
triton/oligoformer-opt/opt/org.py
triton/oligoformer-opt/opt/org.py
+14
-0
triton/oligoformer-opt/opt/trition_opt.py
triton/oligoformer-opt/opt/trition_opt.py
+8
-0
triton/oligoformer-opt/org_code/case-0-opt1.py
triton/oligoformer-opt/org_code/case-0-opt1.py
+272
-0
triton/oligoformer-opt/org_code/case-0.py
triton/oligoformer-opt/org_code/case-0.py
+202
-0
triton/oligoformer-opt/org_code/demo.py
triton/oligoformer-opt/org_code/demo.py
+78
-0
triton/oligoformer-opt/org_code/matmul-sample.py
triton/oligoformer-opt/org_code/matmul-sample.py
+170
-0
triton/oligoformer-opt/org_code/run.sh
triton/oligoformer-opt/org_code/run.sh
+4
-0
triton/oligoformer-opt/org_code/samples/clean_hipprof.sh
triton/oligoformer-opt/org_code/samples/clean_hipprof.sh
+4
-0
triton/oligoformer-opt/org_code/samples/matmul-sample.py
triton/oligoformer-opt/org_code/samples/matmul-sample.py
+177
-0
triton/oligoformer-opt/org_code/samples/mlp-sample.py
triton/oligoformer-opt/org_code/samples/mlp-sample.py
+270
-0
triton/oligoformer-opt/org_code/trition_opt.py
triton/oligoformer-opt/org_code/trition_opt.py
+232
-0
No files found.
clean_hipprof.sh
0 → 100644
View file @
c601083d
# Remove hipprof profiling artifacts (databases, CSV/TXT dumps, JSON traces)
# from the current directory. -f keeps this quiet when nothing matches.
rm -rf *.db *.csv *.txt *.json
gemm_cutss/demo.cpp
0 → 100644
View file @
c601083d
/***************************************************************************************************
* Copyright (c) 2017 - 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
* modification, are permitted provided that the following conditions are met:
*
* 1. Redistributions of source code must retain the above copyright notice, this
* list of conditions and the following disclaimer.
*
* 2. Redistributions in binary form must reproduce the above copyright notice,
* this list of conditions and the following disclaimer in the documentation
* and/or other materials provided with the distribution.
*
* 3. Neither the name of the copyright holder nor the names of its
* contributors may be used to endorse or promote products derived from
* this software without specific prior written permission.
*
* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*
**************************************************************************************************/
/*
This example demonstrates how to call a CUTLASS GEMM kernel and provides a naive reference
matrix multiply kernel to verify its correctness.
The CUTLASS Gemm template is instantiated in the function CutlassSgemmNN. This is kernel computes
the general matrix product (GEMM) using single-precision floating-point arithmetic and assumes
all matrices have column-major layout.
The threadblock tile size is chosen as 128x128x8 which offers good performance for large matrices.
See the CUTLASS Parallel for All blog post for more exposition on the tunable parameters available
in CUTLASS.
https://devblogs.nvidia.com/cutlass-linear-algebra-cuda/
Aside from defining and launching the SGEMM kernel, this example does not use any other components
or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are
prevalent in the CUTLASS unit tests.
This example has delibrately been kept similar to the basic_gemm example from cutlass-1.3 to
highlight the minimum amount of differences needed to transition to cutlass-2.0.
Cutlass-1.3 sgemm: https://github.com/NVIDIA/cutlass/blob/master/examples/00_basic_gemm/basic_gemm.cu
*/
// Standard Library includes
#include <iostream>
#include <sstream>
#include <vector>
// Helper methods to check for errors
#include "helper.h"
//
// CUTLASS includes needed for single-precision GEMM kernel
//
// Defines cutlass::gemm::device::Gemm, the generic Gemm computation template class.
#include "cutlass/gemm/device/gemm.h"
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// This function defines a CUTLASS GEMM kernel instantiation, constructs its parameters object,
// and launches it on the CUDA device.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Define a CUTLASS GEMM template and launch a GEMM kernel.
/// Define a CUTLASS GEMM template and launch a GEMM kernel.
///
/// Computes C = alpha * A * B + beta * C for single-precision, column-major
/// operands using the default 128x128x8 threadblock tile.
cudaError_t CutlassSgemmNN(
    int M, int N, int K,
    float alpha,
    float const *A, int lda,
    float const *B, int ldb,
    float beta,
    float *C, int ldc) {

  // Type definition for single-precision CUTLASS GEMM with column-major
  // input matrices. Typical values are supplied as default template
  // arguments; see `cutlass/gemm/device/default_gemm_configuration.h` and
  // `cutlass/gemm/device/gemm.h` for the full device API.
  using ColumnMajor = cutlass::layout::ColumnMajor;
  using CutlassGemm = cutlass::gemm::device::Gemm<float,         // Data-type of A matrix
                                                  ColumnMajor,   // Layout of A matrix
                                                  float,         // Data-type of B matrix
                                                  ColumnMajor,   // Layout of B matrix
                                                  float,         // Data-type of C matrix
                                                  ColumnMajor>;  // Layout of C matrix

  CutlassGemm gemm_operator;

  // CUTLASS argument objects are host-constructible and passed to the
  // kernel by value: pointers, strides, and epilogue scalars in one bundle.
  CutlassGemm::Arguments args({M, N, K},       // GEMM problem dimensions
                              {A, lda},        // Tensor-ref for source matrix A
                              {B, ldb},        // Tensor-ref for source matrix B
                              {C, ldc},        // Tensor-ref for source matrix C
                              {C, ldc},        // Tensor-ref for destination D (may alias C)
                              {alpha, beta});  // Scalars used in the epilogue

  // Launch the CUTLASS GEMM kernel.
  cutlass::Status status = gemm_operator(args);

  // Map any CUTLASS failure onto a generic CUDA error code.
  if (status != cutlass::Status::kSuccess) {
    return cudaErrorUnknown;
  }
  return cudaSuccess;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
//
// The source code after this point in the file is generic CUDA using the CUDA Runtime API
// and simple CUDA kernels to initialize matrices and compute the general matrix product.
//
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Kernel to initialize a matrix with small integers.
/// Kernel to initialize a matrix with small integers.
///
/// Element (i, j) of the column-major matrix gets a deterministic value in
/// [-m/2, m/2) derived from its linear offset and the seed.
__global__ void InitializeMatrix_kernel(
    float *matrix,
    int rows,
    int columns,
    int seed = 0) {

  int const i = threadIdx.x + blockIdx.x * blockDim.x;
  int const j = threadIdx.y + blockIdx.y * blockDim.y;

  // Guard clause: threads outside the matrix do nothing.
  if (i >= rows || j >= columns) {
    return;
  }

  int const offset = i + j * rows;  // column-major linear index

  // Generate arbitrary elements via a small LCG-style hash.
  int const k = 16807;
  int const m = 16;
  matrix[offset] = float(((offset + seed) * k % m) - m / 2);
}
/// Simple function to initialize a matrix to arbitrary small integers.
/// Simple function to initialize a matrix to arbitrary small integers.
///
/// Launches InitializeMatrix_kernel over a 16x16 thread grid covering the
/// whole (rows x columns) matrix and returns the launch status.
cudaError_t InitializeMatrix(
    float *matrix,
    int rows,
    int columns,
    int seed = 0) {

  dim3 const block(16, 16);
  dim3 const grid((rows + block.x - 1) / block.x,
                  (columns + block.y - 1) / block.y);

  InitializeMatrix_kernel<<<grid, block>>>(matrix, rows, columns, seed);

  return cudaGetLastError();
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Allocates device memory for a matrix then fills with arbitrary small integers.
/// Allocates device memory for a matrix then fills it with arbitrary small integers.
///
/// On success *matrix points at a zero-initialized, then seeded, device
/// allocation of rows*columns floats. On any failure after allocation the
/// memory is released and *matrix is reset to nullptr (the original leaked
/// the allocation on the memset/initialize failure paths).
cudaError_t AllocateMatrix(
    float **matrix,
    int rows,
    int columns,
    int seed = 0) {

  cudaError_t result;

  size_t sizeof_matrix = sizeof(float) * rows * columns;

  // Allocate device memory.
  result = cudaMalloc(reinterpret_cast<void **>(matrix), sizeof_matrix);
  if (result != cudaSuccess) {
    std::cerr << "Failed to allocate matrix: "
              << cudaGetErrorString(result) << std::endl;
    return result;
  }

  // Clear the allocation.
  result = cudaMemset(*matrix, 0, sizeof_matrix);
  if (result != cudaSuccess) {
    std::cerr << "Failed to clear matrix device memory: "
              << cudaGetErrorString(result) << std::endl;
    // Fix: release the allocation instead of leaking it on this error path.
    cudaFree(*matrix);
    *matrix = nullptr;
    return result;
  }

  // Initialize matrix elements to arbitrary small integers.
  result = InitializeMatrix(*matrix, rows, columns, seed);
  if (result != cudaSuccess) {
    std::cerr << "Failed to initialize matrix: "
              << cudaGetErrorString(result) << std::endl;
    // Fix: release the allocation instead of leaking it on this error path.
    cudaFree(*matrix);
    *matrix = nullptr;
    return result;
  }

  return result;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Naive reference GEMM computation.
/// Naive reference GEMM computation.
///
/// One thread per output element: C[i,j] = alpha * sum_k A[i,k]*B[k,j]
/// + beta * C[i,j], all matrices column-major.
__global__ void ReferenceGemm_kernel(
    int M, int N, int K,
    float alpha,
    float const *A, int lda,
    float const *B, int ldb,
    float beta,
    float *C, int ldc) {

  int const i = threadIdx.x + blockIdx.x * blockDim.x;
  int const j = threadIdx.y + blockIdx.y * blockDim.y;

  // Threads outside the output matrix have nothing to do.
  if (i >= M || j >= N) {
    return;
  }

  float accumulator = 0;
  for (int k = 0; k < K; ++k) {
    accumulator += A[i + k * lda] * B[k + j * ldb];
  }

  C[i + j * ldc] = alpha * accumulator + beta * C[i + j * ldc];
}
/// Reference GEMM computation.
/// Reference GEMM computation.
///
/// Launches ReferenceGemm_kernel with one 16x16 thread per output element
/// and returns the launch status.
cudaError_t ReferenceGemm(
    int M, int N, int K,
    float alpha,
    float const *A, int lda,
    float const *B, int ldb,
    float beta,
    float *C, int ldc) {

  dim3 const block(16, 16);
  dim3 const grid((M + block.x - 1) / block.x,
                  (N + block.y - 1) / block.y);

  ReferenceGemm_kernel<<<grid, block>>>(
      M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);

  return cudaGetLastError();
}
///////////////////////////////////////////////////////////////////////////////////////////////////
#define TILE_SIZE 16

/// Shared-memory tiled GEMM kernel (column-major operands).
///
/// Each 16x16 thread block computes one TILE_SIZE x TILE_SIZE tile of C,
/// marching over K in TILE_SIZE-wide slabs staged through shared memory.
__global__ void TiledGemm_kernel(
    int M, int N, int K,
    float alpha,
    float const *A, int lda,
    float const *B, int ldb,
    float beta,
    float *C, int ldc) {

  __shared__ float As[TILE_SIZE][TILE_SIZE];
  __shared__ float Bs[TILE_SIZE][TILE_SIZE];

  int const tx = threadIdx.x;
  int const ty = threadIdx.y;

  int const row = blockIdx.y * TILE_SIZE + ty;  // output row this thread owns
  int const col = blockIdx.x * TILE_SIZE + tx;  // output column this thread owns

  float accumulator = 0;

  int const numTiles = (K + TILE_SIZE - 1) / TILE_SIZE;
  for (int tile = 0; tile < numTiles; ++tile) {
    int const kA = tile * TILE_SIZE + tx;  // K index loaded from A by this thread
    int const kB = tile * TILE_SIZE + ty;  // K index loaded from B by this thread

    // Stage one tile of each operand; out-of-range slots are zero-padded so
    // the inner product below needs no per-element bounds checks.
    As[ty][tx] = (row < M && kA < K) ? A[row + kA * lda] : 0;
    Bs[ty][tx] = (col < N && kB < K) ? B[kB + col * ldb] : 0;

    __syncthreads();

    for (int k = 0; k < TILE_SIZE; ++k) {
      accumulator += As[ty][k] * Bs[k][tx];
    }

    // Keep the next tile's loads from racing ahead of this tile's reads.
    __syncthreads();
  }

  if (row < M && col < N) {
    C[row + col * ldc] = alpha * accumulator + beta * C[row + col * ldc];
  }
}
/// Host launcher for the shared-memory tiled GEMM kernel.
///
/// Grid x covers N, grid y covers M, each block TILE_SIZE x TILE_SIZE threads.
cudaError_t TiledGemm(
    int M, int N, int K,
    float alpha,
    float const *A, int lda,
    float const *B, int ldb,
    float beta,
    float *C, int ldc) {

  dim3 const block(TILE_SIZE, TILE_SIZE);
  dim3 const grid((N + TILE_SIZE - 1) / TILE_SIZE,
                  (M + TILE_SIZE - 1) / TILE_SIZE);

  TiledGemm_kernel<<<grid, block>>>(
      M, N, K, alpha, A, lda, B, ldb, beta, C, ldc);

  return cudaGetLastError();
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Allocate several matrices in GPU device memory and call a single-precision
/// CUTLASS GEMM kernel.
/// Allocate several matrices in GPU device memory and call a single-precision
/// CUTLASS GEMM kernel, verifying it bit-for-bit against the naive reference.
cudaError_t TestCutlassGemm(int M, int N, int K, float alpha, float beta) {
  cudaError_t result;

  // Leading dimensions for column-major storage.
  int const lda = M;
  int const ldb = K;
  int const ldc = M;

  // Size in bytes of the C matrix.
  size_t const sizeof_C = sizeof(float) * ldc * N;

  // Device pointers; nullptr-initialized so release_all() is always safe
  // (cudaFree(nullptr) is a no-op).
  float *A = nullptr;
  float *B = nullptr;
  float *C_cutlass = nullptr;
  float *C_reference = nullptr;

  // Frees everything allocated so far; used on every exit path.
  auto release_all = [&]() {
    cudaFree(C_reference);
    cudaFree(C_cutlass);
    cudaFree(B);
    cudaFree(A);
  };

  // Allocate matrices in GPU device memory with arbitrary seeds.
  result = AllocateMatrix(&A, M, K, 0);
  if (result != cudaSuccess) {
    release_all();
    return result;
  }
  result = AllocateMatrix(&B, K, N, 17);
  if (result != cudaSuccess) {
    release_all();
    return result;
  }
  result = AllocateMatrix(&C_cutlass, M, N, 101);
  if (result != cudaSuccess) {
    release_all();
    return result;
  }
  // Same seed as C_cutlass: both C operands must start identical.
  result = AllocateMatrix(&C_reference, M, N, 101);
  if (result != cudaSuccess) {
    release_all();
    return result;
  }

  result = cudaMemcpy(C_reference, C_cutlass, sizeof_C, cudaMemcpyDeviceToDevice);
  if (result != cudaSuccess) {
    std::cerr << "Failed to copy C_cutlass matrix to C_reference: "
              << cudaGetErrorString(result) << std::endl;
    release_all();
    return result;
  }

  // Launch CUTLASS GEMM.
  result = CutlassSgemmNN(M, N, K, alpha, A, lda, B, ldb, beta, C_cutlass, ldc);
  if (result != cudaSuccess) {
    std::cerr << "CUTLASS GEMM kernel failed: "
              << cudaGetErrorString(result) << std::endl;
    release_all();
    return result;
  }

  // Launch reference GEMM.
  result = ReferenceGemm(M, N, K, alpha, A, lda, B, ldb, beta, C_reference, ldc);
  if (result != cudaSuccess) {
    std::cerr << "Reference GEMM kernel failed: "
              << cudaGetErrorString(result) << std::endl;
    release_all();
    return result;
  }

  // Copy both results to host for comparison.
  std::vector<float> host_cutlass(ldc * N, 0);
  std::vector<float> host_reference(ldc * N, 0);

  result = cudaMemcpy(host_cutlass.data(), C_cutlass, sizeof_C, cudaMemcpyDeviceToHost);
  if (result != cudaSuccess) {
    std::cerr << "Failed to copy CUTLASS GEMM results: "
              << cudaGetErrorString(result) << std::endl;
    release_all();
    return result;
  }

  result = cudaMemcpy(host_reference.data(), C_reference, sizeof_C, cudaMemcpyDeviceToHost);
  if (result != cudaSuccess) {
    std::cerr << "Failed to copy Reference GEMM results: "
              << cudaGetErrorString(result) << std::endl;
    release_all();
    return result;
  }

  // Free device memory allocations.
  release_all();

  // Test for bit equivalence of results.
  if (host_cutlass != host_reference) {
    std::cerr << "CUTLASS results incorrect." << std::endl;
    return cudaErrorUnknown;
  }

  return cudaSuccess;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
/// Entry point to basic_gemm example.
//
// usage:
//
// 00_basic_gemm <M> <N> <K> <alpha> <beta>
//
/// Entry point to the basic_gemm example.
///
/// usage: 00_basic_gemm <M> <N> <K> <alpha> <beta>
int main(int argc, const char *arg[]) {
  // GEMM problem dimensions, defaulting to 128^3; argv[1..3] override.
  int problem[3] = {128, 128, 128};
  for (int i = 1; i < argc && i < 4; ++i) {
    std::stringstream ss(arg[i]);
    ss >> problem[i - 1];
  }

  // Scalars for linear scaling of the product; argv[4..5] override.
  float scalars[2] = {1, 0};
  for (int i = 4; i < argc && i < 6; ++i) {
    std::stringstream ss(arg[i]);
    ss >> scalars[i - 4];
  }

  // Run the CUTLASS GEMM test.
  cudaError_t const result = TestCutlassGemm(problem[0],   // GEMM M dimension
                                             problem[1],   // GEMM N dimension
                                             problem[2],   // GEMM K dimension
                                             scalars[0],   // alpha
                                             scalars[1]);  // beta

  if (result == cudaSuccess) {
    std::cout << "Passed." << std::endl;
  }

  // Exit with 0 on success, -1 otherwise.
  return result == cudaSuccess ? 0 : -1;
}
///////////////////////////////////////////////////////////////////////////////////////////////////
\ No newline at end of file
gemm_cutss/demo.cu
0 → 100644
View file @
c601083d
#include <cuda_runtime.h>
#include <iostream>
#include <cmath>
#include <cstdlib>
#define CHECK_CUDA(call) \
do { \
cudaError_t err = call; \
if (err != cudaSuccess) { \
std::cerr << "CUDA error at " << __FILE__ << ":" << __LINE__ << " - " \
<< cudaGetErrorString(err) << std::endl; \
exit(EXIT_FAILURE); \
} \
} while (0)
constexpr
int
kMmaM
=
16
;
constexpr
int
kMmaN
=
16
;
constexpr
int
kMmaK
=
16
;
constexpr
int
kWarpM
=
64
;
constexpr
int
kWarpN
=
64
;
constexpr
int
kWarpK
=
32
;
constexpr
int
kBlockM
=
128
;
constexpr
int
kBlockN
=
128
;
constexpr
int
kBlockK
=
64
;
constexpr
int
kWarpNumM
=
kBlockM
/
kWarpM
;
constexpr
int
kWarpNumN
=
kBlockN
/
kWarpN
;
/// Warp-tiled GEMM kernel (column-major A/B/C with lda=M, ldb=K, ldc=M).
///
/// Stages kBlockM x kBlockK / kBlockK x kBlockN slabs of A and B through
/// shared memory; each thread accumulates a 4x4 fragment.
///
/// NOTE(review): with a 256-thread block there are 8 warps but only
/// kWarpNumM * kWarpNumN = 4 warp tiles, so warpRow reaches 3 and
/// warpStartRow can index past smemA's kBlockM rows — verify launch/tiling
/// parameters. Also, aFrag uses (k + threadColInWarp) while bFrag uses
/// (k + threadRowInWarp) as the K index, and the K loop advances by kMmaK
/// per iteration while each thread reads a single K slice — confirm the
/// intended fragment mapping; as written coverage of the warp tile and of
/// K looks incomplete.
__global__ void TiledGemmKernel(int M, int N, int K, float alpha,
                                const float *__restrict__ A,
                                const float *__restrict__ B,
                                float beta,
                                float *__restrict__ C) {
  const int lda = M;
  const int ldb = K;
  const int ldc = M;

  __shared__ float smemA[kBlockM][kBlockK];
  __shared__ float smemB[kBlockK][kBlockN];

  const int warpId = threadIdx.x / 32;
  const int laneId = threadIdx.x % 32;

  const int warpRow = warpId / kWarpNumN;
  const int warpCol = warpId % kWarpNumN;

  // Each thread is responsible for a 4x4 fragment.
  const int threadRowInWarp = laneId / 4;
  const int threadColInWarp = laneId % 4;

  const int blockRow = blockIdx.y * kBlockM;
  const int blockCol = blockIdx.x * kBlockN;

  // 4x4 accumulator per thread.
  float acc[4][4] = {0};

  const int numTiles = (K + kBlockK - 1) / kBlockK;

  for (int tileIdx = 0; tileIdx < numTiles; ++tileIdx) {
    // Cooperative load of the A slab (tiled along M), zero-padding the edges.
    for (int i = threadIdx.x; i < kBlockM * kBlockK; i += blockDim.x) {
      const int row = i / kBlockK;
      const int col = i % kBlockK;
      const int globalRow = blockRow + row;
      const int globalCol = tileIdx * kBlockK + col;
      smemA[row][col] = (globalRow < M && globalCol < K)
                            ? A[globalRow + globalCol * lda]
                            : 0.0f;
    }

    // Cooperative load of the B slab (tiled along N), zero-padding the edges.
    for (int i = threadIdx.x; i < kBlockK * kBlockN; i += blockDim.x) {
      const int row = i / kBlockN;
      const int col = i % kBlockN;
      const int globalRow = tileIdx * kBlockK + row;
      const int globalCol = blockCol + col;
      smemB[row][col] = (globalRow < K && globalCol < N)
                            ? B[globalRow + globalCol * ldb]
                            : 0.0f;
    }

    __syncthreads();

    // Compute this slab's contribution.
    const int warpStartRow = warpRow * kWarpM;
    const int warpStartCol = warpCol * kWarpN;

    for (int k = 0; k < kBlockK; k += kMmaK) {
      // Each thread pulls 4 elements of A ...
      float aFrag[4];
      #pragma unroll
      for (int i = 0; i < 4; ++i) {
        const int row = warpStartRow + threadRowInWarp + i * 4;
        const int col = k + threadColInWarp;
        aFrag[i] = smemA[row][col];
      }

      // ... and 4 elements of B.
      float bFrag[4];
      #pragma unroll
      for (int j = 0; j < 4; ++j) {
        const int row = k + threadRowInWarp;
        const int col = warpStartCol + threadColInWarp + j * 4;
        bFrag[j] = smemB[row][col];
      }

      // Outer product into the accumulator.
      #pragma unroll
      for (int i = 0; i < 4; ++i) {
        #pragma unroll
        for (int j = 0; j < 4; ++j) {
          acc[i][j] += aFrag[i] * bFrag[j];
        }
      }
    }

    __syncthreads();
  }

  // Write back the 4x4 fragment with bounds checks on the matrix edges.
  const int warpStartRow = blockRow + warpRow * kWarpM;
  const int warpStartCol = blockCol + warpCol * kWarpN;

  for (int i = 0; i < 4; ++i) {
    const int row = warpStartRow + threadRowInWarp + i * 4;
    if (row >= M) continue;
    for (int j = 0; j < 4; ++j) {
      const int col = warpStartCol + threadColInWarp + j * 4;
      if (col >= N) continue;
      const int idx = row + col * ldc;
      C[idx] = alpha * acc[i][j] + beta * C[idx];
    }
  }
}
void
TiledGemm
(
int
M
,
int
N
,
int
K
,
float
alpha
,
const
float
*
A
,
const
float
*
B
,
float
beta
,
float
*
C
)
{
dim3
block
(
256
);
dim3
grid
(
(
N
+
kBlockN
-
1
)
/
kBlockN
,
(
M
+
kBlockM
-
1
)
/
kBlockM
);
TiledGemmKernel
<<<
grid
,
block
>>>
(
M
,
N
,
K
,
alpha
,
A
,
B
,
beta
,
C
);
CHECK_CUDA
(
cudaDeviceSynchronize
());
}
void
ReferenceGemm
(
int
M
,
int
N
,
int
K
,
float
alpha
,
const
float
*
A
,
const
float
*
B
,
float
beta
,
float
*
C
)
{
for
(
int
i
=
0
;
i
<
M
;
++
i
)
{
for
(
int
j
=
0
;
j
<
N
;
++
j
)
{
float
sum
=
0
;
for
(
int
k
=
0
;
k
<
K
;
++
k
)
{
sum
+=
A
[
i
+
k
*
M
]
*
B
[
k
+
j
*
K
];
}
C
[
i
+
j
*
M
]
=
alpha
*
sum
+
beta
*
C
[
i
+
j
*
M
];
}
}
}
/// Fill `data` with `size` pseudo-random floats uniformly drawn from [-1, 1].
/// Uses rand(), so results depend on the process-wide srand() state.
void RandomInit(float *data, int size) {
  for (int idx = 0; idx < size; ++idx) {
    float const unit = float(rand()) / RAND_MAX;  // in [0, 1]
    data[idx] = unit * 2.0f - 1.0f;               // map to [-1, 1]
  }
}
/// Element-wise comparison of two column-major M x N matrices.
///
/// Returns true when every |C1 - C2| is within `tolerance`; otherwise logs
/// the first mismatch to stderr and returns false.
bool Verify(const float *C1, const float *C2, int M, int N,
            float tolerance = 1e-3f) {
  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      int const idx = i + j * M;
      float const diff = fabsf(C1[idx] - C2[idx]);
      if (diff > tolerance) {
        std::cerr << "Mismatch at C[" << i << "," << j << "]: "
                  << C1[idx] << " vs " << C2[idx]
                  << " (diff=" << diff << ")" << std::endl;
        return false;
      }
    }
  }
  return true;
}
/// Driver: run the tiled GEMM on the GPU, time it, and verify against the
/// CPU reference. Optional argv[1..3] override M, N, K (default 512^3).
int main(int argc, char **argv) {
  int M = 512;
  int N = 512;
  int K = 512;
  float const alpha = 1.0f;
  float const beta = 0.0f;

  if (argc >= 4) {
    M = atoi(argv[1]);
    N = atoi(argv[2]);
    K = atoi(argv[3]);
  }

  std::cout << "GEMM: M=" << M << ", N=" << N << ", K=" << K << std::endl;

  // Host buffers.
  float *h_A = new float[M * K];
  float *h_B = new float[K * N];
  float *h_C_tiled = new float[M * N];
  float *h_C_ref = new float[M * N];

  RandomInit(h_A, M * K);
  RandomInit(h_B, K * N);

  // Device buffers.
  float *d_A;
  float *d_B;
  float *d_C;
  CHECK_CUDA(cudaMalloc(&d_A, M * K * sizeof(float)));
  CHECK_CUDA(cudaMalloc(&d_B, K * N * sizeof(float)));
  CHECK_CUDA(cudaMalloc(&d_C, M * N * sizeof(float)));

  CHECK_CUDA(cudaMemcpy(d_A, h_A, M * K * sizeof(float), cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemcpy(d_B, h_B, K * N * sizeof(float), cudaMemcpyHostToDevice));
  CHECK_CUDA(cudaMemset(d_C, 0, M * N * sizeof(float)));

  // Time the tiled GEMM with CUDA events.
  cudaEvent_t start, stop;
  CHECK_CUDA(cudaEventCreate(&start));
  CHECK_CUDA(cudaEventCreate(&stop));

  CHECK_CUDA(cudaEventRecord(start));
  TiledGemm(M, N, K, alpha, d_A, d_B, beta, d_C);
  CHECK_CUDA(cudaEventRecord(stop));
  CHECK_CUDA(cudaEventSynchronize(stop));

  float milliseconds = 0;
  CHECK_CUDA(cudaEventElapsedTime(&milliseconds, start, stop));

  CHECK_CUDA(cudaMemcpy(h_C_tiled, d_C, M * N * sizeof(float), cudaMemcpyDeviceToHost));

  // CPU reference and comparison.
  ReferenceGemm(M, N, K, alpha, h_A, h_B, beta, h_C_ref);
  bool const passed = Verify(h_C_tiled, h_C_ref, M, N);

  // 2*M*N*K flops over the measured time.
  float const tflops = (2.0f * M * N * K) / (milliseconds * 1e-3f) / 1e12f;

  std::cout << "Tiled GEMM: " << milliseconds << " ms" << std::endl;
  std::cout << "Performance: " << tflops << " TFLOPS" << std::endl;
  std::cout << "Result: " << (passed ? "PASSED" : "FAILED") << std::endl;

  // Cleanup.
  CHECK_CUDA(cudaEventDestroy(start));
  CHECK_CUDA(cudaEventDestroy(stop));
  CHECK_CUDA(cudaFree(d_A));
  CHECK_CUDA(cudaFree(d_B));
  CHECK_CUDA(cudaFree(d_C));
  delete[] h_A;
  delete[] h_B;
  delete[] h_C_tiled;
  delete[] h_C_ref;

  return passed ? 0 : 1;
}
\ No newline at end of file
triton/oligoformer-opt/opt/__pycache__/org.cpython-310.pyc
0 → 100644
View file @
c601083d
File added
triton/oligoformer-opt/opt/demo-opt.py
0 → 100644
View file @
c601083d
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
# 模拟 grab_first_if_tuple(如果返回的是 tuple 取第一个,否则原样返回)
def grab_first_if_tuple(x):
    """Return the first element when *x* is a tuple, otherwise *x* unchanged."""
    if isinstance(x, tuple):
        return x[0]
    return x
class ParallelGatedMLP(nn.Module):
    """Gated MLP block with two parallel input projections and one output
    projection.

    NOTE(review): ``forward`` currently returns only the two parallel
    projections ``(z1, z2)``; the activation ``self.act`` and the output
    projection ``self.l3`` are constructed but not applied there.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx

        # NOTE(review): read from config but otherwise unused in this demo.
        multiple_of = config.get("inner_size_multiple_of", 64)

        # Pick the activation by name; unknown names are rejected.
        self.act_type = config.get("mlp_activation", "gelu")
        if self.act_type == "gelu":
            self.act = F.gelu
        elif self.act_type == "silu":
            self.act = F.silu
        else:
            raise NotImplementedError

        # Evo2-style configs disable the activation on layers past the first.
        if self.layer_idx > 0 and config.get("evo2_style_activations", False):
            self.act = nn.Identity()

        inner_size = 11264  # hard-coded inner width for this benchmark
        hidden_size = config.get("hidden_size", 4096)

        self.l1 = nn.Linear(
            in_features=hidden_size,
            out_features=inner_size,
            bias=False,
        )
        self.l2 = nn.Linear(
            in_features=hidden_size,
            out_features=inner_size,
            bias=False,
        )
        self.l3 = nn.Linear(
            in_features=inner_size,
            out_features=hidden_size,
            bias=False,
        )

        # Ensure the weights are contiguous (Linear usually is already;
        # this is defensive).
        for proj in (self.l1, self.l2, self.l3):
            proj.weight = torch.nn.Parameter(proj.weight.contiguous())

    def forward(self, z):
        z1 = self.l1(z)
        z2 = self.l2(z)
        return z1, z2
# === 示例调用 ===
# === Example invocation ===
if __name__ == "__main__":
    # Mock configuration.
    config = {
        "hidden_size": 4096,
        "mlp_activation": "silu",
        # Fixed: key was "model_parallel_size:q" — a stray vim ":q" command
        # had been typed into the literal.
        "model_parallel_size": 1,
        "evo2_style_activations": False,
    }
    layer_idx = 0

    # Build the model instance.
    model = ParallelGatedMLP(config, layer_idx)

    # Cast to bfloat16 and move to the GPU.
    model = model.to(dtype=torch.bfloat16, device="cuda:0")

    # Input tensor (batch=1, seq_len=1, hidden=4096).
    device = "cuda:0"  # requires a bf16-capable GPU (e.g. A100 / H100)
    x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)

    # Inference only; no gradients needed.
    with torch.no_grad():
        output = model(x)
\ No newline at end of file
triton/oligoformer-opt/opt/demo.py
0 → 100644
View file @
c601083d
import
triton
import
triton.language
as
tl
# @triton.jit
# def gated_mlp_kernel(
# # 输入
# x_ptr, # [M, K]
# w1_ptr, # [N, K] -> 注意:w1 是 out_features x in_features
# w2_ptr,
# w3_ptr, # [K_out, N] = [hidden, inner]
# y_ptr, # output [M, K_out]
# # 形状
# M, # batch * seq_len
# K, # hidden_size (e.g., 4096)
# N, # inner_size (e.g., 11264)
# K_out: tl.constexpr,
# # 分块
# BLOCK_M: tl.constexpr = 64,
# BLOCK_N: tl.constexpr = 128,
# BLOCK_K: tl.constexpr = 64,
# ):
# pid_m = tl.program_id(0)
# pid_n = tl.program_id(1)
# # 计算当前 block 覆盖的输出区域: [pid_m*BLOCK_M : ..., pid_n*BLOCK_N : ...]
# offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
# offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
# offs_k = tl.arange(0, BLOCK_K)
# # 加载 x 的一行(或几行)
# x_ptrs = x_ptr + offs_m[:, None] * K + offs_k[None, :]
# w1_ptrs = w1_ptr + offs_n[:, None] * K + offs_k[None, :]
# w2_ptrs = w2_ptr + offs_n[:, None] * K + offs_k[None, :]
# # 初始化累加器
# acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
# for k in range(0, K, BLOCK_K):
# # 边界处理
# k_mask = (offs_k[None, :] < K - k)
# x = tl.load(x_ptrs, mask=k_mask, other=0.0)
# w1 = tl.load(w1_ptrs, mask=k_mask, other=0.0)
# w2 = tl.load(w2_ptrs, mask=k_mask, other=0.0)
# acc1 += tl.dot(x, w1.T)
# acc2 += tl.dot(x, w2.T)
# x_ptrs += BLOCK_K
# w1_ptrs += BLOCK_K
# w2_ptrs += BLOCK_K
# offs_k += BLOCK_K
# # 应用 SiLU: x * sigmoid(x)
# z1 = acc1.to(tl.bfloat16)
# z2 = acc2.to(tl.bfloat16)
# sig = tl.sigmoid(z1)
# gated = z1 * sig * z2 # SiLU(z1) * z2
# # 第二阶段:gated @ w3.T → [M, N] @ [K_out, N].T = [M, K_out]
# # 注意:w3 是 [K_out, N],我们要做 gated (M,N) × w3.T (N, K_out)
# offs_k2 = tl.arange(0, BLOCK_K)
# w3_ptrs = w3_ptr + offs_n[:, None] + offs_k2[None, :] * N # w3[k_out, n] → 列主序?
# # 更安全的方式:假设 w3 是 [K_out, N],按行存储,则 w3[k, n] = w3_ptr[k*N + n]
# # 所以要加载 w3 的第 n 列 → 需要转置视角
# # 我们改用:对每个输出列 k_out,累加 gated[:, n] * w3[k_out, n]
# # 所以启动 grid 时,pid_n 对应 k_out,需要调整逻辑
# # ⚠️ 上面的设计有问题!更好的方式是分两个 kernel:
# # 1. 计算 gated = SiLU(x@W1) * (x@W2) → [M, N]
# # 2. gated @ W3.T → [M, K_out]
# # 因为 N=11264 很大,直接三重融合会导致寄存器溢出
# # 因此,我们只融合前两步 + activation,第三步用 cuBLAS(torch.matmul)
@triton.jit
def gated_proj_kernel(
    x_ptr,
    w1_ptr,
    w2_ptr,
    out_ptr,
    M,
    K,
    N,
    stride_xm,
    stride_xk,
    stride_wk,
    stride_wn,  # w is [N, K], so stride_wn = K
    stride_om,
    stride_on,
    ACTIVATION: tl.constexpr,
    BLOCK_M: tl.constexpr = 64,
    BLOCK_N: tl.constexpr = 64,
    BLOCK_K: tl.constexpr = 32,
):
    """Fused gated projection: out = act(x @ w1.T) * (x @ w2.T).

    x is [M, K]; w1/w2 are [N, K] (out_features x in_features, as stored by
    nn.Linear); out is [M, N] in bfloat16. ACTIVATION selects "silu",
    "gelu" (tanh approximation), or a plain elementwise product.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    w1_ptrs = w1_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
    w2_ptrs = w2_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk

    acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_K):
        # BUG FIX: the original also did `offs_k += BLOCK_K` each iteration
        # on top of advancing the pointers, so the tail mask
        # `offs_k < K - k` compared an already-advanced index against an
        # advanced bound and dropped valid elements whenever
        # K % BLOCK_K != 0. offs_k must stay the local 0..BLOCK_K offset.
        k_mask = offs_k[None, :] < K - k
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & k_mask, other=0.0)
        w1 = tl.load(w1_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
        w2 = tl.load(w2_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
        acc1 += tl.dot(x, w1.T)
        acc2 += tl.dot(x, w2.T)
        x_ptrs += BLOCK_K * stride_xk
        w1_ptrs += BLOCK_K * stride_wk
        w2_ptrs += BLOCK_K * stride_wk

    # Downcast before the gating math, matching the bf16 output dtype.
    z1 = acc1.to(tl.bfloat16)
    z2 = acc2.to(tl.bfloat16)
    if ACTIVATION == "silu":
        sig = tl.sigmoid(z1)
        out = z1 * sig * z2  # SiLU(z1) * z2
    elif ACTIVATION == "gelu":
        # Triton has no builtin gelu; use the tanh approximation.
        out = z1 * 0.5 * (1 + tl.tanh(0.79788456 * (z1 + 0.044715 * z1 * z1 * z1))) * z2
    else:
        out = z1 * z2

    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
\ No newline at end of file
triton/oligoformer-opt/opt/org.py
0 → 100644
View file @
c601083d
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
if __name__ == "__main__":
    # Mock configuration (mirrors demo-opt.py).
    config = {
        "hidden_size": 4096,
        "mlp_activation": "silu",
        # Fixed: key was "model_parallel_size:q" — a stray vim ":q" command
        # had been typed into the literal.
        "model_parallel_size": 1,
        "evo2_style_activations": False,
    }
\ No newline at end of file
triton/oligoformer-opt/opt/trition_opt.py
0 → 100644
View file @
c601083d
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
if __name__ == "__main__":
    # Placeholder entry point; the Triton-optimized path is not implemented yet.
    pass
\ No newline at end of file
triton/oligoformer-opt/org_code/case-0-opt1.py
0 → 100644
View file @
c601083d
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
triton
import
triton.language
as
tl
import
numpy
as
np
import
random
import
time
@triton.jit
def gated_proj_kernel(
    x_ptr, w1_ptr, w2_ptr, out_ptr,
    M, K, N,
    stride_xm, stride_xk,
    stride_w1k, stride_w1n,  # w1 is [K, N]
    stride_w2k, stride_w2n,  # w2 is [K, N]
    stride_om, stride_on,
    ACTIVATION: tl.constexpr,
    BLOCK_M: tl.constexpr = 64,
    BLOCK_N: tl.constexpr = 64,
    BLOCK_K: tl.constexpr = 32,
):
    # Fused gated projection: out = act(x @ w1) * (x @ w2).
    # x is [M, K] bf16; w1/w2 are pre-transposed to [K, N] by the caller;
    # accumulation is fp32, the store converts back to bf16.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # x: [M, K]
    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    # w1 and w2: [K, N] (transposed weights).
    w1_ptrs = w1_ptr + offs_k[:, None] * stride_w1k + offs_n[None, :] * stride_w1n
    w2_ptrs = w2_ptr + offs_k[:, None] * stride_w2k + offs_n[None, :] * stride_w2n
    acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        # Masks guard the ragged M/N edges and the final partial K tile.
        x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K - k)
        x = tl.load(x_ptrs, mask=x_mask, other=0.0)
        w_mask = (offs_k[:, None] < K - k) & (offs_n[None, :] < N)
        w1 = tl.load(w1_ptrs, mask=w_mask, other=0.0)
        w2 = tl.load(w2_ptrs, mask=w_mask, other=0.0)
        # w1/w2 are already [K, N], so x @ w1 equals x @ W1^T of the
        # original nn.Linear weight ([N, K]).
        acc1 += tl.dot(x, w1)
        acc2 += tl.dot(x, w2)
        # Advance all pointers one K-tile.
        x_ptrs += BLOCK_K * stride_xk
        w1_ptrs += BLOCK_K * stride_w1k
        w2_ptrs += BLOCK_K * stride_w2k
    # BUG FIX: the original assigned `out` only in the "silu" branch; the
    # gelu and no-activation branches were commented out, so any other
    # ACTIVATION left `out` undefined and failed Triton compilation.
    if ACTIVATION == "silu":
        # SiLU(x) = x * sigmoid(x); gated output is SiLU(x@w1) * (x@w2).
        sig = tl.sigmoid(acc1)
        out = acc1 * sig * acc2
    elif ACTIVATION == "gelu":
        # tanh approximation: GELU(x) ≈ 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
        gelu_approx = 0.5 * acc1 * (1 + tl.tanh(0.79788456 * (acc1 + 0.044715 * acc1 * acc1 * acc1)))
        out = gelu_approx * acc2
    else:
        # No activation: plain elementwise product of the two projections.
        out = acc1 * acc2
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, out.to(tl.bfloat16),
             mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
def fused_gated_proj(x, w1, w2, activation="silu"):
    """
    x: [M, K] - input
    w1: [N, K] - weight1 (PyTorch Linear weight, shape [out_features, in_features])
    w2: [N, K] - weight2 (PyTorch Linear weight, shape [out_features, in_features])
    Returns: [M, N]
    Computes: activation(w1 @ x^T)^T * (w2 @ x^T)^T
    Equivalent to: SiLU(x @ w1^T) * (x @ w2^T)
    """
    assert x.dtype == torch.bfloat16
    assert w1.dtype == torch.bfloat16 and w2.dtype == torch.bfloat16
    M, K = x.shape  # e.g. M=1, K=4096
    N, K2 = w1.shape  # e.g. N=11264, K2=4096 (Linear stores [out, in])
    assert K == K2, f"Dimension mismatch: x K={K}, w1 K={K2}"
    assert w2.shape == (N, K), f"w2 shape mismatch: expected {(N, K)}, got {w2.shape}"
    # Transpose the weights to [K, N] up front so the kernel computes a
    # plain x @ w without needing an in-kernel transpose.
    w1_t = w1.t().contiguous()  # [K, N]
    w2_t = w2.t().contiguous()  # [K, N]
    out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N'])
    )
    gated_proj_kernel[grid](
        x, w1_t, w2_t, out,  # pass the pre-transposed weights
        M, K, N,
        x.stride(0), x.stride(1),
        w1_t.stride(0), w1_t.stride(1),  # strides of the [K, N] layout
        w2_t.stride(0), w2_t.stride(1),
        out.stride(0), out.stride(1),
        ACTIVATION=activation,
        BLOCK_M=64,
        BLOCK_N=64,
        BLOCK_K=32,
    )
    return out
class ParallelGatedMLP(nn.Module):
    # SwiGLU-style gated MLP: gated = act(l1(z)) * l2(z). Only the gating
    # stage is exercised here; l3 (the down projection) is defined but unused.
    def __init__(self):
        super().__init__()
        self.act = F.silu
        self.act_type = "silu"
        self.l1 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l2 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l3 = nn.Linear(
            in_features=11264,
            out_features=4096,
            bias=False,
        )

    def forward_org(self, z):
        """Reference implementation using plain PyTorch linears."""
        shape = z.shape
        z_flat = z.view(-1, shape[-1])  # [M, K]
        # PyTorch: F.linear(x, weight) = x @ weight^T
        # z1 = F.linear(z_flat, self.l1.weight)  # [M, N]
        # z2 = F.linear(z_flat, self.l2.weight)  # [M, N]
        z1, z2 = self.l1(z_flat), self.l2(z_flat)
        gated = self.act(z1) * z2
        return gated

    def forward_opt(self, z):
        """Triton-optimized implementation (fused gated projection)."""
        shape = z.shape
        z_flat = z.view(-1, shape[-1])  # [M, K]
        # Triton path: single fused kernel for both projections + gating.
        gated = fused_gated_proj(
            z_flat,
            self.l1.weight,  # [N, K]
            self.l2.weight,  # [N, K]
            activation=self.act_type
        )
        return gated
if __name__ == "__main__":
    # Seed everything for reproducibility.
    seed = 1111
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Build the model in bfloat16 on the GPU.
    model = ParallelGatedMLP()
    model = model.to(dtype=torch.bfloat16, device="cuda:0")
    # Benchmark different batch sizes.
    for batch_size in [1]:
        print(f"\n{'=' * 50}")
        print(f"Testing with batch_size={batch_size}")
        print('=' * 50)
        x = torch.randn(batch_size, 1, 4096, dtype=torch.bfloat16, device="cuda:0")
        with torch.no_grad():
            # Warm-up (also triggers Triton JIT compilation).
            for _ in range(3):
                _ = model.forward_org(x)
                _ = model.forward_opt(x)
            # BUG FIX: CUDA launches are asynchronous — without a
            # torch.cuda.synchronize() before/after each timed region,
            # time.time() measures only launch overhead, not execution.
            torch.cuda.synchronize()
            t0 = time.time()
            # Time the reference version.
            for i in range(10):
                result_org = model.forward_org(x)
            torch.cuda.synchronize()
            t1 = time.time()
            # (Also fixed the format spec: ":5f" -> ":.5f".)
            print(f"Time taken for forward_org: {t1 - t0:.5f} seconds")
            # Time the optimized version.
            for i in range(10):
                result_opt = model.forward_opt(x)
            torch.cuda.synchronize()
            print(f"Time taken for forward_opt: {time.time() - t1:.5f} seconds")
# # 验证结果
# print(f"ORG shape: {result_org.shape}")
# print(f"OPT shape: {result_opt.shape}")
# # 计算差异
# diff = torch.abs(result_org - result_opt)
# print(f"Max diff: {diff.max().item():.6f}")
# print(f"Mean diff: {diff.mean().item():.6f}")
# print(f"Min diff: {diff.min().item():.6f}")
# # 相对误差
# rel_error = diff / (torch.abs(result_org) + 1e-8)
# print(f"Max relative error: {rel_error.max().item():.6f}")
# print(f"Mean relative error: {rel_error.mean().item():.6f}")
# # 验证前几个值
# print("\nFirst 10 values comparison:")
# print(f"ORG: {result_org[0, :10].float().cpu().numpy()}")
# print(f"OPT: {result_opt[0, :10].float().cpu().numpy()}")
# print(f"Diff: {diff[0, :10].float().cpu().numpy()}")
# # 检查是否匹配
# if torch.allclose(result_org, result_opt, rtol=1e-2, atol=1e-3):
# print("✓ Results match within tolerance!")
# else:
# print("✗ Results do not match!")
# # 额外的验证:检查数学等价性
# print(f"\n{'='*50}")
# print("Additional mathematical verification")
# print('='*50)
# # 使用小矩阵验证
# test_x = torch.randn(2, 16, dtype=torch.bfloat16, device="cuda:0")
# test_w1 = torch.randn(32, 16, dtype=torch.bfloat16, device="cuda:0")
# test_w2 = torch.randn(32, 16, dtype=torch.bfloat16, device="cuda:0")
# # PyTorch 计算
# z1_pt = F.linear(test_x, test_w1) # x @ w1^T
# z2_pt = F.linear(test_x, test_w2) # x @ w2^T
# result_pt = F.silu(z1_pt) * z2_pt
# # Triton 计算
# result_triton = fused_gated_proj(test_x, test_w1, test_w2, activation="silu")
# diff_test = torch.abs(result_pt - result_triton)
# print(f"Test max diff: {diff_test.max().item():.6f}")
# print(f"Test mean diff: {diff_test.mean().item():.6f}")
# if torch.allclose(result_pt, result_triton, rtol=1e-2, atol=1e-3):
# print("✓ Test passed: Triton implementation matches PyTorch!")
# else:
# print("✗ Test failed: Triton implementation doesn't match PyTorch!")
\ No newline at end of file
triton/oligoformer-opt/org_code/case-0.py
0 → 100644
View file @
c601083d
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
triton
import
triton.language
as
tl
import
time
import
torch
import
numpy
as
np
import
random
@triton.jit
def gated_proj_kernel(
    x_ptr, w1_ptr, w2_ptr, out_ptr,
    M, K, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,  # w is [N, K], so stride_wn = K
    stride_om, stride_on,
    ACTIVATION: tl.constexpr,
    BLOCK_M: tl.constexpr = 64,
    BLOCK_N: tl.constexpr = 64,
    BLOCK_K: tl.constexpr = 32,
):
    # Fused gated projection working directly on [N, K] Linear weights:
    # out = act(x @ w1^T) * (x @ w2^T), fp32 accumulation, bf16 store.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    w1_ptrs = w1_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
    w2_ptrs = w2_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
    acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        # offs_k stays block-local ([0, BLOCK_K)); with the K - k bound this
        # masks exactly the final partial K tile.
        k_mask = offs_k[None, :] < K - k
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & k_mask, other=0.0)
        w1 = tl.load(w1_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
        w2 = tl.load(w2_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
        # w tiles are [BLOCK_N, BLOCK_K]; transpose for x @ w^T.
        acc1 += tl.dot(x, w1.T)
        acc2 += tl.dot(x, w2.T)
        x_ptrs += BLOCK_K * stride_xk
        w1_ptrs += BLOCK_K * stride_wk
        w2_ptrs += BLOCK_K * stride_wk
        # BUG FIX: the original also did `offs_k += BLOCK_K` here. The
        # pointers above already advance along K, so incrementing offs_k as
        # well turns the mask `offs_k < K - k` false for the second half of
        # K, silently zeroing those loads and corrupting the result.
    z1 = acc1.to(tl.float32)
    z2 = acc2.to(tl.float32)
    if ACTIVATION == "silu":
        sig = tl.sigmoid(z1)
        out = z1 * sig * z2
    elif ACTIVATION == "gelu":
        # Triton has no native gelu; the original fell back to the silu
        # form here (the tanh approximation is left commented upstream).
        sig = tl.sigmoid(z1)
        out = z1 * sig * z2
    else:
        out = z1 * z2
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, out.to(tl.bfloat16),
             mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
def fused_gated_proj(x, w1, w2, activation="silu"):
    # Launch gated_proj_kernel on the raw [N, K] Linear weights (no host-side
    # transpose in this variant). Returns [M, N] bf16.
    assert x.dtype == torch.bfloat16
    assert w1.dtype == torch.bfloat16 and w2.dtype == torch.bfloat16
    M, K = x.shape  # e.g. 1, 4096
    N, _ = w1.shape  # N = out_features (11264); second dim is K
    assert w2.shape == (N, K)
    out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N'])
    )
    gated_proj_kernel[grid](
        x, w1, w2, out,
        M, K, N,
        x.stride(0), x.stride(1),
        # stride_wk = stride along K, stride_wn = stride along N of the
        # [N, K] weight; w2 is assumed to share w1's (contiguous) layout.
        w1.stride(1), w1.stride(0),
        out.stride(0), out.stride(1),
        ACTIVATION=activation,
        BLOCK_M=64,
        BLOCK_N=64,
        BLOCK_K=32,
    )
    return out
class ParallelGatedMLP(nn.Module):
    # Gated MLP whose gating stage runs either through plain PyTorch
    # (forward_org) or the fused Triton kernel (forward / forward_opt).
    def __init__(self):
        super().__init__()
        self.act = F.silu
        self.act_type = "silu"
        self.l1 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l2 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l3 = nn.Linear(
            in_features=11264,
            out_features=4096,
            bias=False,
        )
        # Make sure the weights are contiguous (Linear usually is already;
        # this is belt-and-braces for the Triton stride assumptions).
        self.l1.weight = torch.nn.Parameter(self.l1.weight.contiguous())
        self.l2.weight = torch.nn.Parameter(self.l2.weight.contiguous())
        self.l3.weight = torch.nn.Parameter(self.l3.weight.contiguous())

    def forward(self, z):
        # z: [B, S, D] -> flatten to [M, D]
        shape = z.shape
        z_flat = z.view(-1, int(shape[-1]))  # [M, D]
        # Triton path (the l3 down-projection is intentionally left out).
        gated = fused_gated_proj(
            z_flat,
            self.l1.weight,  # [inner, hidden]
            self.l2.weight,
            activation=self.act_type
        )
        # y_flat = self.l3(gated)  # [M, D]
        # y = y_flat.view(*shape)
        return gated

    def forward_org(self, z):
        # Native PyTorch path (used for GELU or while debugging).
        shape = z.shape
        z_flat = z.view(-1, shape[-1])
        z1, z2 = self.l1(z_flat), self.l2(z_flat)
        gated = self.act(z1) * z2
        return gated

    def forward_opt(self, z):
        # z: [B, S, D] -> flatten to [M, D]
        shape = z.shape
        z_flat = z.view(-1, int(shape[-1]))  # [M, D]
        # Triton path.
        gated = fused_gated_proj(
            z_flat,
            self.l1.weight,  # [inner, hidden]
            self.l2.weight,
            activation=self.act_type
        )
        return gated
if __name__ == "__main__":
    # Seed everything for reproducibility.
    seed = 1111
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if using multi-GPU
    np.random.seed(seed)
    random.seed(seed)
    # Optional: trade performance for reproducibility (some CUDA ops are
    # non-deterministic otherwise).
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Create the model instance.
    model = ParallelGatedMLP()
    # Convert the model to bfloat16.
    model = model.to(dtype=torch.bfloat16, device="cuda:0")
    # Input tensor (batch=1, seq_len=1, hidden=4096).
    device = "cuda:0"  # needs a GPU with bf16 support (e.g. A100, H100)
    x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)
    # Compare reference and Triton outputs by eye (first 20 values).
    with torch.no_grad():
        result_org = model.forward_org(x)
        print(f"ORG: {result_org[0, :20]}")
        result_opt = model.forward_opt(x)
        print(f"OPT: {result_opt[0, :20]}")
\ No newline at end of file
triton/oligoformer-opt/org_code/demo.py
0 → 100644
View file @
c601083d
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
# 模拟 grab_first_if_tuple(如果返回的是 tuple 取第一个,否则原样返回)
def grab_first_if_tuple(x):
    """Return the first element when *x* is a tuple; otherwise return *x* unchanged."""
    if isinstance(x, tuple):
        return x[0]
    return x
class ParallelGatedMLP(nn.Module):
    """Gated MLP block: l1/l2 project hidden -> inner, l3 projects back.

    ``forward`` returns the two pre-activation projections (z1, z2); the
    gating itself is applied by the other variants in this repo.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        # (Removed an unused local: the original read
        # config.get("inner_size_multiple_of", 64) into `multiple_of`
        # and never used it — inner_size is hard-coded below.)
        self.act_type = config.get("mlp_activation", "gelu")
        if self.act_type == "gelu":
            self.act = F.gelu
        elif self.act_type == "silu":
            self.act = F.silu
        else:
            raise NotImplementedError
        # Evo2-style models skip the activation on all but the first layer.
        if self.layer_idx > 0 and config.get("evo2_style_activations", False):
            self.act = nn.Identity()
        inner_size = 11264
        self.l1 = nn.Linear(
            in_features=config.get("hidden_size", 4096),
            out_features=inner_size,
            bias=False,
        )
        self.l2 = nn.Linear(
            in_features=config.get("hidden_size", 4096),
            out_features=inner_size,
            bias=False,
        )
        self.l3 = nn.Linear(
            in_features=inner_size,
            out_features=config.get("hidden_size", 4096),
            bias=False,
        )
        # Make sure the weights are contiguous (Linear usually is already;
        # this is belt-and-braces for kernels that assume dense layouts).
        self.l1.weight = torch.nn.Parameter(self.l1.weight.contiguous())
        self.l2.weight = torch.nn.Parameter(self.l2.weight.contiguous())
        self.l3.weight = torch.nn.Parameter(self.l3.weight.contiguous())

    def forward(self, z):
        # Return both projections without gating.
        z1, z2 = self.l1(z), self.l2(z)
        return z1, z2
# === Example invocation ===
if __name__ == "__main__":
    # Mock configuration.
    config = {
        "hidden_size": 4096,
        "mlp_activation": "silu",
        # BUG FIX: the key used to be "model_parallel_size:q" — a stray
        # ":q" (vim quit command) had been typed into the string literal.
        "model_parallel_size": 1,
        "evo2_style_activations": False,
    }
    layer_idx = 0
    # Create the model instance.
    model = ParallelGatedMLP(config, layer_idx)
    # Convert the model to bfloat16.
    model = model.to(dtype=torch.bfloat16, device="cuda:0")
    # Input tensor (batch=1, seq_len=1, hidden=4096).
    device = "cuda:0"  # needs a GPU with bf16 support (e.g. A100, H100)
    x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)
    # Inference only.
    with torch.no_grad():
        for i in range(10):
            output = model(x)
\ No newline at end of file
triton/oligoformer-opt/org_code/matmul-sample.py
0 → 100644
View file @
c601083d
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
triton
import
triton.language
as
tl
import
numpy
as
np
import
random
import
time
@triton.jit
def matmul_kernel(
    x_ptr, w_ptr, out_ptr,
    M, K, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,  # w is [K, N] (already transposed by the caller)
    stride_om, stride_on,
    BLOCK_M: tl.constexpr = 64,
    BLOCK_N: tl.constexpr = 64,
    BLOCK_K: tl.constexpr = 32,
):
    # Tiled GEMM: out[M, N] = x[M, K] @ w[K, N], fp32 accumulation, bf16 store.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    w_ptrs = w_ptr + offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        # Masks guard the ragged M/N edges and the final partial K tile.
        x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K - k)
        w_mask = (offs_k[:, None] < K - k) & (offs_n[None, :] < N)
        x = tl.load(x_ptrs, mask=x_mask, other=0.0)
        w = tl.load(w_ptrs, mask=w_mask, other=0.0)
        acc += tl.dot(x, w)
        x_ptrs += BLOCK_K * stride_xk
        w_ptrs += BLOCK_K * stride_wk
    # Convert the fp32 accumulator to bfloat16 for the output.
    out = acc.to(tl.bfloat16)
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
def triton_matmul(x, weight):
    """
    Compute y = x @ weight.T using Triton.
    x: [M, K], dtype=bfloat16
    weight: [N, K], dtype=bfloat16 (PyTorch Linear weight, shape [out_features, in_features])
    Returns: y: [M, N], dtype=bfloat16
    """
    assert x.dtype == torch.bfloat16
    assert weight.dtype == torch.bfloat16
    assert x.device == weight.device
    assert x.is_contiguous()
    M, K = x.shape
    N, K2 = weight.shape
    assert K == K2, f"K mismatch: {K} != {K2}"
    # Transpose the weight to [K, N] up front so the kernel can use it
    # directly: weight is [N, K], we need weight.T = [K, N].
    w_t = weight.t().contiguous()  # [K, N]
    out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N'])
    )
    # Note: the transposed weight w_t ([K, N]) is what gets passed in.
    matmul_kernel[grid](
        x, w_t, out,
        M, K, N,
        x.stride(0), x.stride(1),
        w_t.stride(0), w_t.stride(1),
        out.stride(0), out.stride(1),
        BLOCK_M=64,
        BLOCK_N=64,
        BLOCK_K=32,
    )
    return out
class ParallelGatedMLP(nn.Module):
    # Gated-MLP layer sized for a 4096-hidden model; only l1 is exercised
    # by the matmul comparison below.
    def __init__(self):
        super().__init__()
        self.act = F.silu
        self.act_type = "silu"
        self.l1 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l2 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l3 = nn.Linear(
            in_features=11264,
            out_features=4096,
            bias=False,
        )

    def forward_org(self, z):
        # Reference path: plain PyTorch linear (y = z @ W1^T).
        shape = z.shape
        z_flat = z.view(-1, shape[-1])  # [M, K]
        y = F.linear(z_flat, self.l1.weight, bias=None)  # [M, N]
        return y

    def forward_org_triton(self, z):
        # Triton path: the same product computed by triton_matmul.
        shape = z.shape
        z_flat = z.view(-1, shape[-1])  # [M, K]
        y = triton_matmul(z_flat, self.l1.weight)  # [M, N]
        return y
if __name__ == "__main__":
    # Seed everything for reproducibility.
    seed = 1111
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    model = ParallelGatedMLP()
    model = model.to(dtype=torch.bfloat16, device="cuda:0")
    device = "cuda:0"
    x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)
    # Correctness check: PyTorch reference vs Triton matmul.
    with torch.no_grad():
        result_org = model.forward_org(x)
        result_opt = model.forward_org_triton(x)
        print(f"ORG shape: {result_org.shape}")
        print(f"OPT shape: {result_opt.shape}")
        # Compare the first 20 elements by eye.
        print(f"\nORG first 20: {result_org[0, :20]}")
        print(f"OPT first 20: {result_opt[0, :20]}")
        # Absolute difference statistics.
        diff = torch.abs(result_org - result_opt)
        print(f"\nMax diff: {diff.max().item()}")
        print(f"Mean diff: {diff.mean().item()}")
        # Relative error check (epsilon avoids division by zero).
        rel_error = diff / (torch.abs(result_org) + 1e-8)
        print(f"Max relative error: {rel_error.max().item()}")
        # Verify within tolerance (floating-point accumulation differs).
        if torch.allclose(result_org, result_opt, rtol=1e-2, atol=1e-3):
            print("\n✓ Results match within tolerance!")
        else:
            print("\n✗ Results do not match!")
\ No newline at end of file
triton/oligoformer-opt/org_code/run.sh
0 → 100644
View file @
c601083d
export
ROCBLAS_LAYER
=
3
python trition_opt.py
\ No newline at end of file
triton/oligoformer-opt/org_code/samples/clean_hipprof.sh
0 → 100644
View file @
c601083d
# Remove hipprof artifacts from previous profiling runs.
rm -rf *.db *.csv *.txt *.json
\ No newline at end of file
triton/oligoformer-opt/org_code/samples/matmul-sample.py
0 → 100644
View file @
c601083d
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
triton
import
triton.language
as
tl
import
numpy
as
np
import
random
import
time
@triton.jit
def matmul_kernel(
    x_ptr, w_ptr, out_ptr,
    M, K, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,  # w is [K, N] (already transposed by the caller)
    stride_om, stride_on,
    BLOCK_M: tl.constexpr = 64,
    BLOCK_N: tl.constexpr = 64,
    BLOCK_K: tl.constexpr = 32,
):
    # Tiled GEMM: out[M, N] = x[M, K] @ w[K, N], fp32 accumulation, bf16 store.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    w_ptrs = w_ptr + offs_k[:, None] * stride_wk + offs_n[None, :] * stride_wn
    acc = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        # Masks guard the ragged M/N edges and the final partial K tile.
        x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K - k)
        w_mask = (offs_k[:, None] < K - k) & (offs_n[None, :] < N)
        x = tl.load(x_ptrs, mask=x_mask, other=0.0)
        w = tl.load(w_ptrs, mask=w_mask, other=0.0)
        acc += tl.dot(x, w)
        x_ptrs += BLOCK_K * stride_xk
        w_ptrs += BLOCK_K * stride_wk
    # Convert the fp32 accumulator to bfloat16 for the output.
    out = acc.to(tl.bfloat16)
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, out, mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
def triton_matmul(x, weight):
    """
    Compute y = x @ weight.T using Triton.
    x: [M, K], dtype=bfloat16
    weight: [N, K], dtype=bfloat16 (PyTorch Linear weight, shape [out_features, in_features])
    Returns: y: [M, N], dtype=bfloat16
    """
    assert x.dtype == torch.bfloat16
    assert weight.dtype == torch.bfloat16
    assert x.device == weight.device
    assert x.is_contiguous()
    M, K = x.shape
    N, K2 = weight.shape
    assert K == K2, f"K mismatch: {K} != {K2}"
    # Transpose the weight to [K, N] up front so the kernel can use it
    # directly: weight is [N, K], we need weight.T = [K, N].
    w_t = weight.t().contiguous()  # [K, N]
    out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N'])
    )
    # Note: the transposed weight w_t ([K, N]) is what gets passed in.
    matmul_kernel[grid](
        x, w_t, out,
        M, K, N,
        x.stride(0), x.stride(1),
        w_t.stride(0), w_t.stride(1),
        out.stride(0), out.stride(1),
        BLOCK_M=64,
        BLOCK_N=64,
        BLOCK_K=32,
    )
    return out
class ParallelGatedMLP(nn.Module):
    # Gated-MLP layer sized for a 4096-hidden model; only l1 is exercised
    # by the matmul comparison below.
    def __init__(self):
        super().__init__()
        self.act = F.silu
        self.act_type = "silu"
        self.l1 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l2 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l3 = nn.Linear(
            in_features=11264,
            out_features=4096,
            bias=False,
        )

    def forward_org(self, z):
        # Reference path in bfloat16: y = z_flat @ W1^T with
        # l1 = Linear(4096 -> 11264, bias=False), z_flat shape [1, 4096].
        shape = z.shape
        z_flat = z.view(-1, shape[-1])  # [M, K]
        y = F.linear(z_flat, self.l1.weight, bias=None)  # [M, N]
        return y

    def forward_org_triton(self, z):
        # Triton path: the same product computed by triton_matmul.
        shape = z.shape
        z_flat = z.view(-1, shape[-1])  # [M, K]
        y = triton_matmul(z_flat, self.l1.weight)  # [M, N]
        return y
if __name__ == "__main__":
    # Seed everything for reproducibility.
    seed = 1111
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    model = ParallelGatedMLP()
    model = model.to(dtype=torch.bfloat16, device="cuda:0")
    device = "cuda:0"
    x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)
    # Correctness check: PyTorch reference vs Triton matmul.
    with torch.no_grad():
        result_org = model.forward_org(x)
        result_opt = model.forward_org_triton(x)
        print(f"ORG shape: {result_org.shape}")
        print(f"OPT shape: {result_opt.shape}")
        # Compare the first 20 elements by eye.
        print(f"\nORG first 20: {result_org[0, :20]}")
        print(f"OPT first 20: {result_opt[0, :20]}")
        # Absolute difference statistics.
        diff = torch.abs(result_org - result_opt)
        print(f"\nMax diff: {diff.max().item()}")
        print(f"Mean diff: {diff.mean().item()}")
        # Relative error check (epsilon avoids division by zero).
        rel_error = diff / (torch.abs(result_org) + 1e-8)
        print(f"Max relative error: {rel_error.max().item()}")
        # Verify within tolerance (floating-point accumulation differs).
        if torch.allclose(result_org, result_opt, rtol=1e-2, atol=1e-3):
            print("\n✓ Results match within tolerance!")
        else:
            print("\n✗ Results do not match!")
\ No newline at end of file
triton/oligoformer-opt/org_code/samples/mlp-sample.py
0 → 100644
View file @
c601083d
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
triton
import
triton.language
as
tl
import
numpy
as
np
import
random
import
time
@triton.jit
def gated_proj_kernel(
    x_ptr, w1_ptr, w2_ptr, out_ptr,
    M, K, N,
    stride_xm, stride_xk,
    stride_w1k, stride_w1n,  # w1 is [K, N]
    stride_w2k, stride_w2n,  # w2 is [K, N]
    stride_om, stride_on,
    ACTIVATION: tl.constexpr,
    BLOCK_M: tl.constexpr = 64,
    BLOCK_N: tl.constexpr = 64,
    BLOCK_K: tl.constexpr = 32,
):
    # Fused gated projection: out = act(x @ w1) * (x @ w2).
    # x is [M, K] bf16; w1/w2 are pre-transposed to [K, N] by the caller;
    # accumulation is fp32, the store converts back to bf16.
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)
    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)
    # x: [M, K]
    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    # w1 and w2: [K, N] (transposed weights).
    w1_ptrs = w1_ptr + offs_k[:, None] * stride_w1k + offs_n[None, :] * stride_w1n
    w2_ptrs = w2_ptr + offs_k[:, None] * stride_w2k + offs_n[None, :] * stride_w2n
    acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    for k in range(0, K, BLOCK_K):
        # Masks guard the ragged M/N edges and the final partial K tile.
        x_mask = (offs_m[:, None] < M) & (offs_k[None, :] < K - k)
        x = tl.load(x_ptrs, mask=x_mask, other=0.0)
        w_mask = (offs_k[:, None] < K - k) & (offs_n[None, :] < N)
        w1 = tl.load(w1_ptrs, mask=w_mask, other=0.0)
        w2 = tl.load(w2_ptrs, mask=w_mask, other=0.0)
        # w1/w2 are already [K, N], so x @ w1 equals x @ W1^T of the
        # original nn.Linear weight ([N, K]).
        acc1 += tl.dot(x, w1)
        acc2 += tl.dot(x, w2)
        # Advance all pointers one K-tile.
        x_ptrs += BLOCK_K * stride_xk
        w1_ptrs += BLOCK_K * stride_w1k
        w2_ptrs += BLOCK_K * stride_w2k
    # BUG FIX: the original assigned `out` only in the "silu" branch; the
    # gelu and no-activation branches were commented out, so any other
    # ACTIVATION left `out` undefined and failed Triton compilation.
    if ACTIVATION == "silu":
        # SiLU(x) = x * sigmoid(x); gated output is SiLU(x@w1) * (x@w2).
        sig = tl.sigmoid(acc1)
        out = acc1 * sig * acc2
    elif ACTIVATION == "gelu":
        # tanh approximation: GELU(x) ≈ 0.5*x*(1 + tanh(sqrt(2/pi)*(x + 0.044715*x^3)))
        gelu_approx = 0.5 * acc1 * (1 + tl.tanh(0.79788456 * (acc1 + 0.044715 * acc1 * acc1 * acc1)))
        out = gelu_approx * acc2
    else:
        # No activation: plain elementwise product of the two projections.
        out = acc1 * acc2
    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(out_ptrs, out.to(tl.bfloat16),
             mask=(offs_m[:, None] < M) & (offs_n[None, :] < N))
def fused_gated_proj(x, w1, w2, activation="silu"):
    """
    x: [M, K] - input
    w1: [N, K] - weight1 (PyTorch Linear weight, shape [out_features, in_features])
    w2: [N, K] - weight2 (PyTorch Linear weight, shape [out_features, in_features])
    Returns: [M, N]
    Computes: activation(w1 @ x^T)^T * (w2 @ x^T)^T
    Equivalent to: SiLU(x @ w1^T) * (x @ w2^T)
    """
    assert x.dtype == torch.bfloat16
    assert w1.dtype == torch.bfloat16 and w2.dtype == torch.bfloat16
    M, K = x.shape  # e.g. M=1, K=4096
    N, K2 = w1.shape  # e.g. N=11264, K2=4096 (Linear stores [out, in])
    assert K == K2, f"Dimension mismatch: x K={K}, w1 K={K2}"
    assert w2.shape == (N, K), f"w2 shape mismatch: expected {(N, K)}, got {w2.shape}"
    # Transpose the weights to [K, N] up front so the kernel computes a
    # plain x @ w without needing an in-kernel transpose.
    w1_t = w1.t().contiguous()  # [K, N]
    w2_t = w2.t().contiguous()  # [K, N]
    out = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)
    grid = lambda META: (
        triton.cdiv(M, META['BLOCK_M']),
        triton.cdiv(N, META['BLOCK_N'])
    )
    gated_proj_kernel[grid](
        x, w1_t, w2_t, out,  # pass the pre-transposed weights
        M, K, N,
        x.stride(0), x.stride(1),
        w1_t.stride(0), w1_t.stride(1),  # strides of the [K, N] layout
        w2_t.stride(0), w2_t.stride(1),
        out.stride(0), out.stride(1),
        ACTIVATION=activation,
        BLOCK_M=64,
        BLOCK_N=64,
        BLOCK_K=32,
    )
    return out
class ParallelGatedMLP(nn.Module):
    # SwiGLU-style gated MLP: gated = act(l1(z)) * l2(z). Only the gating
    # stage is exercised here; l3 (the down projection) is defined but unused.
    def __init__(self):
        super().__init__()
        self.act = F.silu
        self.act_type = "silu"
        self.l1 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l2 = nn.Linear(
            in_features=4096,
            out_features=11264,
            bias=False,
        )
        self.l3 = nn.Linear(
            in_features=11264,
            out_features=4096,
            bias=False,
        )

    def forward_org(self, z):
        """Reference implementation using plain PyTorch linears."""
        shape = z.shape
        z_flat = z.view(-1, shape[-1])  # [M, K]
        # PyTorch: F.linear(x, weight) = x @ weight^T
        # z1 = F.linear(z_flat, self.l1.weight)  # [M, N]
        # z2 = F.linear(z_flat, self.l2.weight)  # [M, N]
        z1, z2 = self.l1(z_flat), self.l2(z_flat)
        gated = self.act(z1) * z2
        return gated

    def forward_opt(self, z):
        """Triton-optimized implementation (fused gated projection)."""
        shape = z.shape
        z_flat = z.view(-1, shape[-1])  # [M, K]
        # Triton path: single fused kernel for both projections + gating.
        gated = fused_gated_proj(
            z_flat,
            self.l1.weight,  # [N, K]
            self.l2.weight,  # [N, K]
            activation=self.act_type
        )
        return gated
if __name__ == "__main__":
    # Seed everything for reproducibility.
    seed = 1111
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # Build the model in bfloat16 on the GPU.
    model = ParallelGatedMLP()
    model = model.to(dtype=torch.bfloat16, device="cuda:0")
    # Test different batch sizes.
    for batch_size in [1]:
        print(f"\n{'=' * 50}")
        print(f"Testing with batch_size={batch_size}")
        print('=' * 50)
        x = torch.randn(batch_size, 1, 4096, dtype=torch.bfloat16, device="cuda:0")
        with torch.no_grad():
            # Warm-up (also triggers Triton JIT compilation).
            for _ in range(3):
                _ = model.forward_org(x)
                _ = model.forward_opt(x)
            # BUG FIX: CUDA launches are asynchronous — without a
            # torch.cuda.synchronize() before/after each timed region,
            # time.time() measures only launch overhead, not execution.
            torch.cuda.synchronize()
            t0 = time.time()
            # Time the reference version.
            result_org = model.forward_org(x)
            torch.cuda.synchronize()
            t1 = time.time()
            print(f"Time taken for forward_org: {t1 - t0:.4f} seconds")
            # Time the optimized version.
            result_opt = model.forward_opt(x)
            torch.cuda.synchronize()
            # BUG FIX: this line used to print "forward_org" while timing
            # the optimized path.
            print(f"Time taken for forward_opt: {time.time() - t1:.4f} seconds")
            # Verify the results.
            print(f"ORG shape: {result_org.shape}")
            print(f"OPT shape: {result_opt.shape}")
            # Absolute difference statistics.
            diff = torch.abs(result_org - result_opt)
            print(f"Max diff: {diff.max().item():.6f}")
            print(f"Mean diff: {diff.mean().item():.6f}")
            print(f"Min diff: {diff.min().item():.6f}")
            # Relative error (epsilon avoids division by zero).
            rel_error = diff / (torch.abs(result_org) + 1e-8)
            print(f"Max relative error: {rel_error.max().item():.6f}")
            print(f"Mean relative error: {rel_error.mean().item():.6f}")
            # Spot-check the first few values.
            print("\nFirst 10 values comparison:")
            print(f"ORG: {result_org[0, :10].float().cpu().numpy()}")
            print(f"OPT: {result_opt[0, :10].float().cpu().numpy()}")
            print(f"Diff: {diff[0, :10].float().cpu().numpy()}")
            # Tolerance check.
            if torch.allclose(result_org, result_opt, rtol=1e-2, atol=1e-3):
                print("✓ Results match within tolerance!")
            else:
                print("✗ Results do not match!")
# # 额外的验证:检查数学等价性
# print(f"\n{'='*50}")
# print("Additional mathematical verification")
# print('='*50)
# # 使用小矩阵验证
# test_x = torch.randn(2, 16, dtype=torch.bfloat16, device="cuda:0")
# test_w1 = torch.randn(32, 16, dtype=torch.bfloat16, device="cuda:0")
# test_w2 = torch.randn(32, 16, dtype=torch.bfloat16, device="cuda:0")
# # PyTorch 计算
# z1_pt = F.linear(test_x, test_w1) # x @ w1^T
# z2_pt = F.linear(test_x, test_w2) # x @ w2^T
# result_pt = F.silu(z1_pt) * z2_pt
# # Triton 计算
# result_triton = fused_gated_proj(test_x, test_w1, test_w2, activation="silu")
# diff_test = torch.abs(result_pt - result_triton)
# print(f"Test max diff: {diff_test.max().item():.6f}")
# print(f"Test mean diff: {diff_test.mean().item():.6f}")
# if torch.allclose(result_pt, result_triton, rtol=1e-2, atol=1e-3):
# print("✓ Test passed: Triton implementation matches PyTorch!")
# else:
# print("✗ Test failed: Triton implementation doesn't match PyTorch!")
\ No newline at end of file
triton/oligoformer-opt/org_code/trition_opt.py
0 → 100644
View file @
c601083d
import
torch
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
triton
import
triton.language
as
tl
import
time
import
torch
import
numpy
as
np
import
random
# Seed all RNGs for reproducible runs.
seed = 1111
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)  # if using multi-GPU
np.random.seed(seed)
random.seed(seed)
# Optional: trade performance for reproducibility (some CUDA ops are
# non-deterministic otherwise).
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
@triton.jit
def gated_proj_kernel(
    x_ptr, w1_ptr, w2_ptr, out_ptr,
    M, K, N,
    stride_xm, stride_xk,
    stride_wk, stride_wn,  # w is [N, K], so stride_wn = K
    stride_om, stride_on,
    ACTIVATION: tl.constexpr,
    BLOCK_M: tl.constexpr = 64,
    BLOCK_N: tl.constexpr = 64,
    BLOCK_K: tl.constexpr = 32,
):
    """Fused gated projection: out = act(x @ w1^T) * (x @ w2^T).

    x is [M, K]; w1 and w2 are [N, K] (nn.Linear layout, multiplied
    transposed); out is [M, N] and is stored as bfloat16. Both matmuls
    accumulate in float32 and the gating is applied before the single store,
    so the intermediate z1/z2 never round-trip through global memory.
    """
    pid_m = tl.program_id(0)
    pid_n = tl.program_id(1)

    offs_m = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_n = pid_n * BLOCK_N + tl.arange(0, BLOCK_N)
    offs_k = tl.arange(0, BLOCK_K)

    x_ptrs = x_ptr + offs_m[:, None] * stride_xm + offs_k[None, :] * stride_xk
    w1_ptrs = w1_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk
    w2_ptrs = w2_ptr + offs_n[:, None] * stride_wn + offs_k[None, :] * stride_wk

    acc1 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
    acc2 = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)

    for k in range(0, K, BLOCK_K):
        # offs_k stays fixed at 0..BLOCK_K-1; the pointers advance instead,
        # so the valid-column test is offs_k < K - k, i.e. (k + offs_k) < K.
        k_mask = offs_k[None, :] < K - k
        x = tl.load(x_ptrs, mask=(offs_m[:, None] < M) & k_mask, other=0.0)
        w1 = tl.load(w1_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
        w2 = tl.load(w2_ptrs, mask=(offs_n[:, None] < N) & k_mask, other=0.0)
        acc1 += tl.dot(x, w1.T)
        acc2 += tl.dot(x, w2.T)
        x_ptrs += BLOCK_K * stride_xk
        w1_ptrs += BLOCK_K * stride_wk
        w2_ptrs += BLOCK_K * stride_wk
        # BUG FIX: the original additionally did `offs_k += BLOCK_K` here.
        # Since the pointers are already advanced each iteration, bumping
        # offs_k as well made the mask test (2*k + j) < K instead of
        # (k + j) < K, which rejects valid columns for roughly the second
        # half of the K reduction and silently loads zeros — producing a
        # wrong matmul result. offs_k must not be incremented.

    z1 = acc1.to(tl.float32)
    z2 = acc2.to(tl.float32)
    if ACTIVATION == "silu":
        out = z1 * tl.sigmoid(z1) * z2
    elif ACTIVATION == "gelu":
        # Triton has no gelu primitive here; like the original, fall back to
        # silu (a tanh-approximation gelu is possible but kept disabled).
        out = z1 * tl.sigmoid(z1) * z2
    else:
        # No activation: plain elementwise product of the two projections.
        out = z1 * z2

    out_ptrs = out_ptr + offs_m[:, None] * stride_om + offs_n[None, :] * stride_on
    tl.store(
        out_ptrs,
        out.to(tl.bfloat16),
        mask=(offs_m[:, None] < M) & (offs_n[None, :] < N),
    )
def fused_gated_proj(x, w1, w2, activation="silu"):
    """Launch the fused gated-projection kernel: act(x @ w1^T) * (x @ w2^T).

    x:  [M, K] bfloat16 activations.
    w1: [N, K] bfloat16 gate weight (nn.Linear layout).
    w2: [N, K] bfloat16 up-projection weight, same shape as w1.
    Returns a new [M, N] bfloat16 tensor on x's device.
    """
    for t in (x, w1, w2):
        assert t.dtype == torch.bfloat16
    M, K = x.shape
    N = w1.shape[0]
    assert w2.shape == (N, K)

    result = torch.empty(M, N, dtype=torch.bfloat16, device=x.device)

    def grid(meta):
        # 2-D launch grid: one program per (BLOCK_M x BLOCK_N) output tile.
        return (triton.cdiv(M, meta['BLOCK_M']), triton.cdiv(N, meta['BLOCK_N']))

    gated_proj_kernel[grid](
        x, w1, w2, result,
        M, K, N,
        x.stride(0), x.stride(1),
        w1.stride(1), w1.stride(0),  # stride_wk, stride_wn for the [N, K] weight
        result.stride(0), result.stride(1),
        ACTIVATION=activation,
        BLOCK_M=64,
        BLOCK_N=64,
        BLOCK_K=32,
    )
    return result
def grab_first_if_tuple(x):
    """Return the first element when *x* is a tuple; otherwise return *x* unchanged."""
    if isinstance(x, tuple):
        return x[0]
    return x
class ParallelGatedMLP(nn.Module):
    """Gated MLP block (SwiGLU-style): act(l1(z)) * l2(z).

    ``forward``/``forward_opt`` use the fused Triton kernel via
    ``fused_gated_proj``; ``forward_org`` is the eager-PyTorch reference
    path used for correctness comparison.

    NOTE(review): all three forward variants return the gated intermediate
    only — the down-projection ``l3`` is constructed but its application is
    commented out.
    """

    def __init__(self, config, layer_idx):
        super().__init__()
        self.layer_idx = layer_idx
        # NOTE(review): fetched but never used anywhere in this class.
        multiple_of = config.get("inner_size_multiple_of", 64)
        self.act_type = config.get("mlp_activation", "gelu")
        if self.act_type == "gelu":
            self.act = F.gelu
        elif self.act_type == "silu":
            self.act = F.silu
        else:
            raise NotImplementedError
        # Force the fused-kernel activation to silu regardless of config:
        # the Triton kernel only implements silu (its "gelu" branch falls
        # back to silu too). ``self.act`` keeps the configured activation
        # for the reference path.
        self.act_type = "silu"
        if self.layer_idx > 0 and config.get("evo2_style_activations", False):
            self.act = nn.Identity()
        # Hard-coded inner width; presumably model-specific — TODO confirm.
        inner_size = 11264
        self.l1 = nn.Linear(
            in_features=config.get("hidden_size", 4096),
            out_features=inner_size,
            bias=False,
        )
        self.l2 = nn.Linear(
            in_features=config.get("hidden_size", 4096),
            out_features=inner_size,
            bias=False,
        )
        self.l3 = nn.Linear(
            in_features=inner_size,
            out_features=config.get("hidden_size", 4096),
            bias=False,
        )
        # Ensure the weights are contiguous (nn.Linear usually is already,
        # but the Triton kernel's stride arithmetic relies on it).
        self.l1.weight = torch.nn.Parameter(self.l1.weight.contiguous())
        self.l2.weight = torch.nn.Parameter(self.l2.weight.contiguous())
        self.l3.weight = torch.nn.Parameter(self.l3.weight.contiguous())

    def forward(self, z):
        # z: [B, S, D] → flatten to [M, D]
        shape = z.shape
        z_flat = z.view(-1, int(shape[-1]))  # [M, D]
        # Triton path: fused act(l1(z)) * l2(z)
        gated = fused_gated_proj(
            z_flat,
            self.l1.weight,  # [inner, hidden]
            self.l2.weight,
            activation=self.act_type
        )
        # y_flat = self.l3(gated)  # [M, D]
        # y = y_flat.view(*shape)
        return gated

    def forward_org(self, z):
        shape = z.shape
        z_flat = z.view(-1, shape[-1])
        # Native path: used for gelu or when debugging the fused kernel.
        z1, z2 = self.l1(z_flat), self.l2(z_flat)
        gated = self.act(z1) * z2
        return gated

    def forward_opt(self, z):
        # z: [B, S, D] → flatten to [M, D]
        shape = z.shape
        z_flat = z.view(-1, int(shape[-1]))  # [M, D]
        # Triton path (same computation as ``forward``).
        gated = fused_gated_proj(
            z_flat,
            self.l1.weight,  # [inner, hidden]
            self.l2.weight,
            activation=self.act_type
        )
        return gated
if __name__ == "__main__":
    # Mock configuration for a standalone correctness run.
    config = {
        "hidden_size": 4096,
        "mlp_activation": "silu",
        # BUG FIX: this key read "model_parallel_size:q" — a stray vim ":q"
        # had been typed into the string literal.
        "model_parallel_size": 1,
        "evo2_style_activations": False,
    }
    layer_idx = 0

    # Build the model and move it to bfloat16 on the GPU.
    model = ParallelGatedMLP(config, layer_idx)
    model = model.to(dtype=torch.bfloat16, device="cuda:0")

    # Input tensor (batch=1, seq_len=1, hidden=4096). bf16 requires a GPU
    # that supports it (e.g. A100 / H100).
    device = "cuda:0"
    x = torch.randn(1, 1, 4096, dtype=torch.bfloat16, device=device)

    # Print the first 20 outputs of the reference and fused paths so they
    # can be eyeballed against each other.
    with torch.no_grad():
        result_org = model.forward_org(x)
        print(f"ORG: {result_org[0, :20]}")
        result_opt = model.forward_opt(x)
        print(f"OPT: {result_opt[0, :20]}")

    # Benchmark loop (disabled):
    # with torch.no_grad():
    #     for i in range(1000):
    #         output = model(x)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment