#pragma once

// Standard library and CUDA headers required by the declarations in this header.
#include <cstddef>
#include <cstdint>
#include <functional>
#include <unordered_map>

#include <cuda_runtime.h>
#include <cuda_fp16.h>
#include <cuda_bf16.h>

// ATen generator headers, disabled in this build (the Philox RNG field in FwdParams is commented
// out as well).
//#ifdef OLD_GENERATOR_PATH
//#include <ATen/CUDAGeneratorImpl.h>
//#else
//#include <ATen/cuda/CUDAGeneratorImpl.h>
//#endif

namespace layer_norm {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Everything a registered launcher (see FwdFunction) receives: launch-configuration scalars plus
// the kernel parameter struct `Params`.
// NOTE(review): this is an aggregate with no constructor, so the scalar members are left
// uninitialized by `LaunchParams<P> lp;` — callers are presumably expected to assign them before use.
template<typename Params>
struct LaunchParams{

    size_t elts_per_thread;
    size_t workspace_bytes;
    size_t barrier_size;

    int multi_processor_count;

    cudaStream_t stream;

    Params params;

};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Fields shared by the layer-norm parameter structs.
// The default constructor initializes EVERY member (pointers to nullptr, sizes to 0, scale-like
// floats to 1), so a field a caller does not explicitly set is never read as an indeterminate value.
// Initializers are listed in declaration order to match the actual initialization order.
struct ParamsBase {
    ParamsBase()
        : ctas_per_col(0)
        , rows(0)
        , cols(0)
        , x0(nullptr)
        , x1(nullptr)
        , residual(nullptr)
        , x(nullptr)
        , dmask(nullptr)
        , dmask1(nullptr)
        , mu(nullptr)
        , rs(nullptr)
        , gamma(nullptr)
        , gamma1(nullptr)
        , rowscale(nullptr)
        , colscale(nullptr)
        , x0_subset(nullptr)
        , z_subset(nullptr)
        , inverse_cols(0.f)      // cols is 0 by default, so there is no meaningful reciprocal yet.
        , dropout_keep_p(1.f)
        , dropout_scale(1.f)
        , rowscale_const(1.f)    // assumed to be a multiplicative scale, so 1 is neutral — confirm against the kernels.
        , is_rms_norm(false)
        , workspace(nullptr)
        , barrier(nullptr)
    {
    }

    // For Multi-CTA, number of different CTA groups. Otherwise same as gridDim.x.
    int ctas_per_col;

    // Input is interpreted as matrix. We normalize across columns.
    int rows;
    int cols;

    // Common data pointers.
    void *x0;
    void *x1;
    void *residual;
    void *x;
    void *dmask;
    void *dmask1;
    void *mu;
    void *rs;
    void *gamma;
    void *gamma1;
    void *rowscale;
    void *colscale;
    void *x0_subset;
    void *z_subset;

    float inverse_cols;

    float dropout_keep_p;
    float dropout_scale;
    float rowscale_const;

    bool is_rms_norm;

    // Multi-CTA workspace in gmem.
    void *workspace;

    // Multi-CTA sync barriers in gmem.
    int *barrier;

};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Forward-pass parameters: the shared base plus forward-only pointers and epsilon, all initialized
// by the default constructor.
struct FwdParams : public ParamsBase {
    FwdParams()
        : ParamsBase()
        , z(nullptr)
        , z1(nullptr)
        , beta(nullptr)
        , beta1(nullptr)
        , epsilon(0.f)
    {
    }

    // Output of LN FWD.
    void *z;
    void *z1;
    void *beta;
    void *beta1;
    float epsilon;

    // Random state.
    // at::PhiloxCudaState philox_args;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Signature of a registered forward launcher.
using FwdFunction = std::function<void(LaunchParams<FwdParams>&, const bool)>;

// Registry key: upper 32 bits encode the type combination (Types2Key::Value), lower 32 bits the
// hidden size (see Types2Key::get).
using FunctionKey = uint64_t;
using FwdRegistry = std::unordered_map<FunctionKey, FwdFunction>;

// Defined in exactly one translation unit elsewhere; populated by FwdRegistrar instances.
extern FwdRegistry FWD_FUNCS;

////////////////////////////////////////////////////////////////////////////////////////////////////

using fp32 = float;
using fp16 = half;
using bf16 = nv_bfloat16;

////////////////////////////////////////////////////////////////////////////////////////////////////

// Small integer id per supported element type; each id fits in the 2-bit field used by Type2Key.
template<typename T>
struct TypeId{};

template<>
struct TypeId<fp16>{
    constexpr static uint32_t Value = 0;
};

template<>
struct TypeId<bf16>{
    constexpr static uint32_t Value = 1;
};

template<>
struct TypeId<fp32>{
    constexpr static uint32_t Value = 2;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// TypeId<T> shifted into bit position S of the type key.
template<typename T, int S>
struct Type2Key{
    constexpr static uint32_t Value = TypeId<T>::Value << S;
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// One non-overlapping 2-bit field per role in the type key.
template<typename T>
struct WeightType2Key : public Type2Key<T, 0>{};

template<typename T>
struct InputType2Key : public Type2Key<T, 2>{};

template<typename T>
struct ResidualType2Key : public Type2Key<T, 4>{};

template<typename T>
struct OutputType2Key : public Type2Key<T, 6>{};

template<typename T>
struct ComputeType2Key : public Type2Key<T, 8>{};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Combined type key for a (weight, input, residual, output, compute) type tuple.
template<typename W, typename I, typename R, typename O, typename C>
struct Types2Key{
    constexpr static uint32_t Value = WeightType2Key<W>::Value | InputType2Key<I>::Value | ResidualType2Key<R>::Value | OutputType2Key<O>::Value | ComputeType2Key<C>::Value;

    // Full registry key: type key in the upper 32 bits, hidden_size OR-ed into the lower bits.
    // hidden_size must be < 2^32, otherwise it overlaps the type bits.
    constexpr static inline uint64_t get(const uint64_t hidden_size){
        constexpr uint64_t type_key = Value;
        return (type_key << 32) | hidden_size;
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

// Registers `f` in FWD_FUNCS under the key for <W, I, R, O, C, HIDDEN_SIZE> when constructed.
// std::unordered_map::insert keeps the existing entry if the key is already present, so the first
// registration for a given key wins.
template<typename W, typename I, typename R, typename O, typename C, uint64_t HIDDEN_SIZE>
struct FwdRegistrar{
    // The key stores HIDDEN_SIZE in its low 32 bits; larger values would alias the type bits.
    static_assert(HIDDEN_SIZE < (uint64_t(1) << 32), "HIDDEN_SIZE must fit in the low 32 bits of the registry key");

    FwdRegistrar(FwdFunction f){
        uint64_t key = Types2Key<W,I,R,O,C>::get(HIDDEN_SIZE);
        FWD_FUNCS.insert({ key, f });
    }
};

////////////////////////////////////////////////////////////////////////////////////////////////////

}  // namespace layer_norm