Commit a9f1b7af authored by sxtyzhangzk, committed by Zhekai Zhang

[major] move bf16 to compiler args; fp16 experiment

parent d02f26df
@@ -185,6 +185,10 @@ public:
         });
     }
+    void forceFP16Attention(bool enable) {
+        Attention::setForceFP16(net.get(), enable);
+    }
 private:
     void checkModel() {
...
@@ -26,6 +26,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         .def("stopDebug", &QuantizedFluxModel::stopDebug)
         .def("getDebugResults", &QuantizedFluxModel::getDebugResults)
         .def("setLoraScale", &QuantizedFluxModel::setLoraScale)
+        .def("forceFP16Attention", &QuantizedFluxModel::forceFP16Attention)
     ;
     py::class_<QuantizedGEMM>(m, "QuantizedGEMM")
         // .def(torch::init<>())
...
@@ -106,7 +106,7 @@ AdaLayerNormZero::Output AdaLayerNormZero::forward(Tensor x, Tensor emb) {
 Attention::Attention(int num_heads, int dim_head, Device device) :
-    num_heads(num_heads), dim_head(dim_head)
+    num_heads(num_heads), dim_head(dim_head), force_fp16(false)
 {
     headmask_type = Tensor::allocate({num_heads}, Tensor::INT32, Device::cpu());
     for (int i = 0; i < num_heads; i++) {
@@ -116,6 +116,8 @@ Attention::Attention(int num_heads, int dim_head, Device device) :
 }
 Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
+    const bool cast_fp16 = this->force_fp16 && qkv.scalar_type() != Tensor::FP16;
     assert(qkv.ndims() == 3);
     const Device device = qkv.device();
@@ -169,6 +171,14 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
         }
     }
+    if (cast_fp16) {
+        Tensor tmp = Tensor::empty(qkv.shape.dataExtent, Tensor::FP16, qkv.device());
+        cast(qkv, tmp);
+        qkv = tmp;
+    }
+    debug("qkv", qkv);
     Tensor cu_seqlens = cu_seqlens_cpu.copy(device);
     Tensor reshaped = qkv.view({batch_size * num_tokens, num_heads * 3, dim_head});
@@ -192,6 +202,14 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
         false, false, false, -1, -1
     ).front();
+    debug("raw_attn_output", raw_attn_output);
+    if (cast_fp16) {
+        Tensor tmp = Tensor::empty(raw_attn_output.shape.dataExtent, Tensor::BF16, raw_attn_output.device());
+        cast(raw_attn_output, tmp);
+        raw_attn_output = tmp;
+    }
     /**
     Tensor raw_attn_output = mha_varlen_fwd(q, k, v,
         cu_seqlens,
@@ -229,6 +247,16 @@ Tensor Attention::forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio) {
     return raw_attn_output;
 }
+void Attention::setForceFP16(Module *module, bool value) {
+    spdlog::info("{} force fp16 attention", value ? "Enable" : "Disable");
+    module->traverse([&](Module *m) {
+        if (Attention *attn = dynamic_cast<Attention *>(m)) {
+            attn->force_fp16 = value;
+        }
+    });
+}
 FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attention_heads, int attention_head_dim, int mlp_ratio, Tensor::ScalarType dtype, Device device) :
     dim(dim),
     dim_head(attention_head_dim / num_attention_heads),
@@ -250,6 +278,7 @@ FluxSingleTransformerBlock::FluxSingleTransformerBlock(int dim, int num_attentio
         (qkv_proj, "qkv_proj")
         (norm_q, "norm_q")
         (norm_k, "norm_k")
+        (attn, "attn")
         (out_proj, "out_proj")
     ;
 }
@@ -328,6 +357,7 @@ JointTransformerBlock::JointTransformerBlock(int dim, int num_attention_heads, i
         (norm_k, "norm_k")
         (norm_added_q, "norm_added_q")
         (norm_added_k, "norm_added_k")
+        (attn, "attn")
         (out_proj, "out_proj")
         (out_proj_context, "out_proj_context")
         (norm2, "norm2")
...
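For context on the design here: Attention now derives from Module, and each transformer block registers its attention layer as a child ((attn, "attn")), so Attention::setForceFP16 can reach every attention instance by walking the module tree and testing each node with dynamic_cast. The following is a minimal, self-contained sketch of that traversal pattern; the Module and ToyAttention types below are simplified stand-ins, not the project's actual classes.

// Illustrative sketch only (not part of this commit): toy stand-ins for Module/Attention.
#include <cstdio>
#include <functional>
#include <vector>

struct Module {
    std::vector<Module *> children;
    virtual ~Module() = default;
    // Visit this module and every descendant.
    void traverse(const std::function<void(Module *)> &fn) {
        fn(this);
        for (Module *c : children) c->traverse(fn);
    }
};

struct ToyAttention : Module {
    bool force_fp16 = false;
};

// Same pattern as Attention::setForceFP16: walk the tree, flip the flag on attention layers only.
static void setForceFP16(Module *root, bool value) {
    root->traverse([&](Module *m) {
        if (auto *attn = dynamic_cast<ToyAttention *>(m)) {
            attn->force_fp16 = value;
        }
    });
}

int main() {
    Module block;
    ToyAttention attn;
    block.children.push_back(&attn);   // analogous to registering (attn, "attn") as a child
    setForceFP16(&block, true);
    std::printf("force_fp16 = %d\n", attn.force_fp16);
    return 0;
}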
@@ -53,16 +53,19 @@ private:
     LayerNorm norm;
 };
-class Attention {
+class Attention : public Module {
 public:
     static constexpr int POOL_SIZE = 128;
     Attention(int num_heads, int dim_head, Device device);
     Tensor forward(Tensor qkv, Tensor pool_qkv, float sparsityRatio);
+    static void setForceFP16(Module *module, bool value);
 public:
     const int num_heads;
     const int dim_head;
+    bool force_fp16;
 private:
     Tensor cu_seqlens_cpu;
...
@@ -27,7 +27,6 @@
 #include "gemv_awq.h"
 #include "../dispatch_utils.h"
-#define ENABLE_BF16 1
 #include "../utils.cuh"
 #include <cuda_fp16.h>
...
@@ -188,6 +188,28 @@ Tensor quant_static_fuse_gelu(Tensor x, float scale) {
     return out;
 }
+void cast(Tensor input, Tensor output) {
+    assert(input.is_contiguous());
+    assert(output.is_contiguous());
+    assert(input.shape.dataExtent == output.shape.dataExtent);
+    auto stream = getCurrentCUDAStream();
+    dispatch(input.scalar_type(), [&]<typename input_t>() {
+        dispatch(output.scalar_type(), [&]<typename output_t>() {
+            constexpr int unroll = 16 / std::max(sizeof(input_t), sizeof(output_t));
+            int threadsPerBlock = 1024;
+            int blocksPerGrid = (int)ceilDiv<int64_t>(input.numel(), threadsPerBlock * unroll);
+            cast_kernel<input_t, output_t, unroll><<<blocksPerGrid, threadsPerBlock, 0, stream>>>(
+                input.data_ptr<input_t>(), output.data_ptr<output_t>(), input.numel());
+            checkCUDA(cudaGetLastError());
+        });
+    });
+}
 Tensor topk(Tensor x, int k) {
     constexpr int MAXK = 64 + 4;
...
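A note on the launch shape in cast() above: unroll is chosen so each thread moves 16 bytes per vectorized access, i.e. unroll = 16 / max(sizeof(input_t), sizeof(output_t)), and the grid is sized by ceil-dividing numel by threadsPerBlock * unroll. Below is a small illustrative calculation for the BF16 to FP16 case this commit exercises; it is plain host arithmetic under the standard assumption that both element types are 2 bytes.

// Illustrative arithmetic only: the launch shape cast() ends up with for a BF16 -> FP16 cast.
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
    constexpr std::size_t in_bytes = 2;    // sizeof(__nv_bfloat16)
    constexpr std::size_t out_bytes = 2;   // sizeof(half)
    constexpr int unroll = 16 / std::max(in_bytes, out_bytes);   // 8 elements per thread (16-byte vectors)
    const std::int64_t numel = 1 << 20;                          // example element count
    const int threadsPerBlock = 1024;
    const std::int64_t perBlock = std::int64_t(threadsPerBlock) * unroll;   // 8192 elements per block
    const std::int64_t blocksPerGrid = (numel + perBlock - 1) / perBlock;   // ceilDiv(numel, perBlock) = 128
    std::printf("unroll = %d, blocksPerGrid = %lld\n", unroll, (long long)blocksPerGrid);
    return 0;
}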
@@ -11,6 +11,8 @@ void splitqkv(Tensor qkv, Tensor q, Tensor k, Tensor v);
 Tensor quant_static(Tensor x, float scale);
 Tensor quant_static_fuse_gelu(Tensor x, float scale);
+void cast(Tensor input, Tensor output);
 Tensor topk(Tensor x, int k);
 template<size_t N>
...
 #include "reduction_utils.cuh"
 #include <array>
+#include <cuda_fp16.h>
+#include <cuda_bf16.h>
 #include "utils.cuh"
 #include "activation_kernels_impl.cuh"
-#include <cuda_fp16.h>
 template<typename T>
 __global__ void add_kernel(T *a, T *b, T *c, size_t length) {
@@ -162,7 +164,27 @@ __global__ void quant_kernel_static_fuse_gelu(const T * input, int8_t * output,
     *reinterpret_cast<I8vec *>(&output[i]) = routput;
 }
-#include <cstdio>
+template<typename Tin, typename Tout, int unroll>
+__global__ void cast_kernel(const Tin *input, Tout *output, size_t length) {
+    const int i = (blockIdx.x * blockDim.x + threadIdx.x) * unroll;
+    using Tvec_in = ::Tvec<Tin, unroll>;
+    using Tvec_out = ::Tvec<Tout, unroll>;
+    Tvec_in rinput = *reinterpret_cast<const Tvec_in *>(&input[i]);
+    Tvec_out routput;
+#pragma unroll
+    for (int k = 0; k < unroll; k++) {
+        routput.data[k] = cuda_cast<Tout, Tin>(rinput.data[k]);
+        if constexpr (std::is_same_v<Tout, half>) {
+            routput.data[k] = __hmin(routput.data[k], (half)65504);
+            routput.data[k] = __hmax(routput.data[k], (half)-65504);
+        }
+    }
+    *reinterpret_cast<Tvec_out *>(&output[i]) = routput;
+}
 // input: [..., N]
 // output: [..., K] of index in reverse order
...
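One reason cast_kernel clamps when the destination is half: BF16 has the same exponent range as FP32, so finite BF16 values can exceed 65504, the largest finite FP16 value, and a plain conversion would produce infinities. Below is a self-contained CUDA sketch of the same saturating idea, written against plain float/half so it does not depend on the project's Tvec or cuda_cast helpers; names such as saturating_cast are illustrative only.

// Illustrative CUDA sketch (not part of this commit): saturate to FP16's finite range before converting.
#include <cstdio>
#include <cuda_fp16.h>

__global__ void saturating_cast(const float *in, half *out, int n) {
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) {
        float clamped = fminf(fmaxf(in[i], -65504.f), 65504.f);  // clamp to FP16 finite range
        out[i] = __float2half_rn(clamped);
    }
}

int main() {
    const int n = 4;
    float h_in[n] = {1.0f, 70000.0f, -1e6f, 3.14f};   // 70000 and -1e6 overflow FP16
    float *d_in; half *d_out; half h_out[n];
    cudaMalloc(&d_in, n * sizeof(float));
    cudaMalloc(&d_out, n * sizeof(half));
    cudaMemcpy(d_in, h_in, n * sizeof(float), cudaMemcpyHostToDevice);
    saturating_cast<<<1, 32>>>(d_in, d_out, n);
    cudaMemcpy(h_out, d_out, n * sizeof(half), cudaMemcpyDeviceToHost);
    for (int i = 0; i < n; i++)
        std::printf("%g -> %g\n", h_in[i], __half2float(h_out[i]));
    cudaFree(d_in);
    cudaFree(d_out);
    return 0;
}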
@@ -8,6 +8,10 @@
 #include <cuda_fp16.h>
+#ifdef ENABLE_BF16
+#include <cuda_bf16.h>
+#endif
 template<typename T> struct num_elems;
 template <> struct num_elems<float> { static constexpr int value = 1; };
 template <> struct num_elems<float2> { static constexpr int value = 2; };
...
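Per the commit title, ENABLE_BF16 is no longer hard-coded in a source file and is presumably supplied through compiler arguments instead; the actual build-system change is not shown in this diff. A minimal sketch of how a translation unit stays buildable either way follows; the -DENABLE_BF16=1 flag in the comment is a generic nvcc/compiler option used as an assumption, not taken from the project's build files.

// Illustrative only: BF16 support gated on an externally supplied macro,
// e.g. built with:  nvcc -DENABLE_BF16=1 example.cu
#include <cstdio>
#include <cuda_fp16.h>
#ifdef ENABLE_BF16
#include <cuda_bf16.h>
using packed_t = __nv_bfloat162;   // BF16 path, only compiled when the flag is set
#else
using packed_t = half2;            // FP16 fallback otherwise
#endif

int main() {
    std::printf("packed element type is %zu bytes\n", sizeof(packed_t));
    return 0;
}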