Unverified commit 8dd12c87, authored by Jeffrey Morgan, committed by GitHub

llama: update to commit e1e8e099 (#10513)

parent e6d2d041
@@ -2,6 +2,9 @@
 #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
 
+void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
 void ggml_cuda_op_mul_mat_vec_q(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
...
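Context, not part of the diff: in ggml's CUDA backend an `ids` tensor carries the per-token expert indices used by GGML_OP_MUL_MAT_ID, so the added declaration gives the quantized mat-vec path a single entry point for both plain and expert-routed products. A minimal call-site sketch follows; the helper name and the nullptr-means-no-routing convention are assumptions for illustration, not part of the commit.

#include "common.cuh" // ggml_backend_cuda_context, ggml_tensor
#include "mmvq.cuh"   // ggml_cuda_mul_mat_vec_q

// Hypothetical dispatch: src0 = quantized weights (or stacked expert
// matrices), src1 = float activations, dst = output. A non-null ids
// selects which expert matrix each row of src1 is multiplied against;
// nullptr keeps the ordinary MUL_MAT behaviour.
static void mul_mat_vec_q(ggml_backend_cuda_context & ctx,
                          const ggml_tensor * src0, const ggml_tensor * src1,
                          const ggml_tensor * ids, ggml_tensor * dst) {
    ggml_cuda_mul_mat_vec_q(ctx, src0, src1, ids, dst);
}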
@@ -12,13 +12,16 @@ static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk of out-of-bounds access.");
 static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
 
 typedef void (*quantize_cuda_t)(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, const int32_t * ids, void * vy,
+    ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+    int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
 
 void quantize_row_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, const int32_t * ids, void * vy,
+    ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+    int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
 
 void quantize_mmq_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, const int32_t * ids, void * vy,
+    ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+    int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
+#pragma once
+
 #include "common.cuh"
 #include <cstdint>
...
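Note on the signature change, as I read it from the ggml naming conventions (an assumption, verify against quantize.cu): the flattened kx0/kx1/channels/kx0_padded view is replaced by the full 4D shape, with ne00 and s01..s03 describing the source layout in elements, ne0..ne3 the destination's logical shape with ne0 padded for the MMQ/MMVQ kernels, plus an optional ids row mapping for the MUL_MAT_ID path. A hypothetical invocation for a contiguous activation tensor, with illustrative sizes:

#include "quantize.cuh" // quantize_row_q8_1_cuda, MATRIX_ROW_PADDING; pulls in ggml.h

// Minimal sketch (hypothetical sizes, assumed parameter semantics):
// quantize a contiguous float tensor of logical shape ne00 x ne1 x ne2 x ne3
// to q8_1 ahead of a quantized matmul.
void quantize_activations_q8_1(const float * x, void * vy, cudaStream_t stream) {
    const int64_t ne00 = 4096, ne1 = 8, ne2 = 1, ne3 = 1; // example shape
    const int64_t s01 = ne00;    // elements between rows of x
    const int64_t s02 = s01*ne1; // elements between planes of x
    const int64_t s03 = s02*ne2; // elements between batches of x
    const int64_t ne0 = GGML_PAD(ne00, MATRIX_ROW_PADDING); // padded row size
    // ids == nullptr: plain mul-mat; a non-null ids would route rows for the
    // MUL_MAT_ID (mixture-of-experts) path. type_src0 tells the quantizer
    // which weight type the result will be multiplied against.
    quantize_row_q8_1_cuda(x, /*ids=*/nullptr, vy,
        GGML_TYPE_Q4_0, ne00, s01, s02, s03,
        ne0, ne1, ne2, ne3, stream);
}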
@@ -5690,7 +5690,7 @@ kernel void kernel_flash_attn_ext(
     {
         float S[Q] = { [0 ... Q-1] = 0.0f };
-        float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
+        float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
 
         // thread indices inside the simdgroup
         // TODO: see if we can utilize quad-group functions for better performance
@@ -5950,7 +5950,7 @@ kernel void kernel_flash_attn_ext(
         // reduce the warps sequentially
         for (ushort sg = 1; sg < nsg; ++sg) {
             float S = { 0.0f };
-            float M = { -__FLT16_MAX__/2 };
+            float M = { -__FLT_MAX__/2 };
 
             threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -6197,7 +6197,7 @@ kernel void kernel_flash_attn_ext_vec(
     {
         float S = 0.0f;
-        float M = -__FLT16_MAX__/2;
+        float M = -__FLT_MAX__/2;
 
         // thread indices inside the simdgroup
         const short tx = tiisg%NL;
...
@@ -3237,7 +3237,7 @@ kernel void kernel_flash_attn_ext(
     {
         float S[Q] = { [0 ... Q-1] = 0.0f };
-        float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
+        float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
 
         // thread indices inside the simdgroup
         // TODO: see if we can utilize quad-group functions for better performance
@@ -3497,7 +3497,7 @@ kernel void kernel_flash_attn_ext(
         // reduce the warps sequentially
         for (ushort sg = 1; sg < nsg; ++sg) {
             float S = { 0.0f };
-            float M = { -__FLT16_MAX__/2 };
+            float M = { -__FLT_MAX__/2 };
 
             threadgroup_barrier(mem_flags::mem_threadgroup);
@@ -3744,7 +3744,7 @@ kernel void kernel_flash_attn_ext_vec(
     {
         float S = 0.0f;
-        float M = -__FLT16_MAX__/2;
+        float M = -__FLT_MAX__/2;
 
         // thread indices inside the simdgroup
         const short tx = tiisg%NL;
...
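Background on why this constant matters: in these flash-attention kernels S and M are the running denominator and running maximum of an online softmax, and since the accumulators are float the sentinel should be a float extreme, hence -__FLT_MAX__/2 rather than the half-precision -__FLT16_MAX__/2; halving leaves headroom so differences against the sentinel stay finite. A minimal standalone sketch of the recurrence (plain C++, not the kernel code):

#include <cfloat>
#include <cmath>

// Online softmax over one row of attention scores: M tracks the running
// maximum, S the running sum of exponentials, rescaled whenever M grows.
// Initializing M to -FLT_MAX/2 (not -FLT_MAX) keeps M - M_new finite.
float online_softmax_denominator(const float * scores, int n) {
    float M = -FLT_MAX/2; // running maximum, matches the kernels' init
    float S = 0.0f;       // running sum of exp(score - M)
    for (int i = 0; i < n; ++i) {
        const float M_new = fmaxf(M, scores[i]);
        S = S*expf(M - M_new) + expf(scores[i] - M_new);
        M = M_new;
    }
    return S; // final weights would be exp(scores[i] - M) / S
}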