Unverified commit 8dd12c87, authored by Jeffrey Morgan, committed by GitHub

llama: update to commit e1e8e099 (#10513)

parent e6d2d041
@@ -2,6 +2,9 @@
 #define MMVQ_MAX_BATCH_SIZE 8 // Max. batch size for which to use MMVQ kernels.
 
+void ggml_cuda_mul_mat_vec_q(ggml_backend_cuda_context & ctx,
+    const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst);
+
 void ggml_cuda_op_mul_mat_vec_q(
     ggml_backend_cuda_context & ctx,
     const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, const char * src0_dd_i, const float * src1_ddf_i,
...
#include "quantize.cuh" #include "quantize.cuh"
#include <cstdint> #include <cstdint>
static __global__ void quantize_q8_1(const float * __restrict__ x, void * __restrict__ vy, const int64_t kx, const int64_t kx0_padded) { static __global__ void quantize_q8_1(
const int64_t ix0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x; const float * __restrict__ x, void * __restrict__ vy,
const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
const int64_t ne0, const int ne1, const int ne2) {
const int64_t i0 = (int64_t)blockDim.x*blockIdx.x + threadIdx.x;
if (ix0 >= kx0_padded) { if (i0 >= ne0) {
return; return;
} }
const int64_t ix1 = blockIdx.y; const int64_t i1 = blockIdx.y;
const int64_t i2 = blockIdx.z % ne2;
const int64_t i3 = blockIdx.z / ne2;
const int64_t i_padded = ix1*kx0_padded + ix0; const int64_t & i00 = i0;
const int64_t & i01 = i1;
const int64_t & i02 = i2;
const int64_t & i03 = i3;
const int64_t i_cont = ((i3*ne2 + i2) * ne1 + i1) * ne0 + i0;
block_q8_1 * y = (block_q8_1 *) vy; block_q8_1 * y = (block_q8_1 *) vy;
const int64_t ib = i_padded / QK8_1; // block index const int64_t ib = i_cont / QK8_1; // block index
const int64_t iqs = i_padded % QK8_1; // quant index const int64_t iqs = i_cont % QK8_1; // quant index
const float xi = ix0 < kx ? x[ix1*kx + ix0] : 0.0f; const float xi = i0 < ne00 ? x[i03*s03 + i02*s02 + i01*s01 + i00] : 0.0f;
float amax = fabsf(xi); float amax = fabsf(xi);
float sum = xi; float sum = xi;
amax = warp_reduce_max(amax); amax = warp_reduce_max(amax);
sum = warp_reduce_sum(sum); sum = warp_reduce_sum(sum);
const float d = amax / 127; const float d = amax / 127;
const int8_t q = amax == 0.0f ? 0 : roundf(xi / d); const int8_t q = amax == 0.0f ? 0 : roundf(xi / d);
y[ib].qs[iqs] = q; y[ib].qs[iqs] = q;
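For reference, the per-block math above (amax over a QK8_1 = 32 wide block, d = amax/127, q = round(x/d)) can be written as a scalar host-side sketch; this is illustrative only, not part of the commit, and it ignores the d/sum bookkeeping the kernel stores next to the quants.

#include <math.h>
#include <stdint.h>

// scalar sketch of what one QK8_1 = 32 value block of quantize_q8_1 computes
static void quantize_q8_1_block_ref(const float * x, int8_t * qs) {
    float amax = 0.0f;
    for (int i = 0; i < 32; ++i) {
        amax = fmaxf(amax, fabsf(x[i]));   // warp_reduce_max in the kernel
    }
    const float d = amax / 127;
    for (int i = 0; i < 32; ++i) {
        qs[i] = amax == 0.0f ? 0 : (int8_t) roundf(x[i] / d);
    }
}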
@@ -39,29 +49,38 @@ static __global__ void quantize_q8_1(const float * __restrict__ x, void * __rest
 template <mmq_q8_1_ds_layout ds_layout>
 static __global__ void quantize_mmq_q8_1(
-    const float * __restrict__ x, void * __restrict__ vy, const int64_t kx0, const int64_t kx1, const int64_t kx0_padded) {
+        const float * __restrict__ x, const int32_t * __restrict__ ids, void * __restrict__ vy,
+        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t ne0, const int ne1, const int ne2) {
 
     constexpr int vals_per_scale = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 64 : 32;
     constexpr int vals_per_sum   = ds_layout == MMQ_Q8_1_DS_LAYOUT_D2S6 ? 16 : 32;
 
-    const int64_t ix0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
+    const int64_t i0 = ((int64_t)blockDim.x*blockIdx.x + threadIdx.x)*4;
 
-    if (ix0 >= kx0_padded) {
+    if (i0 >= ne0) {
         return;
     }
 
-    const float4 * x4 = (const float4 *) x;
+    const int64_t i1 = blockIdx.y;
+    const int64_t i2 = blockIdx.z % ne2;
+    const int64_t i3 = blockIdx.z / ne2;
 
-    const int64_t ix1 = kx1*blockIdx.z + blockIdx.y;
+    const int64_t i00 = i0;
+    const int64_t i01 = ids ? ids[i1] : i1;
+    const int64_t i02 = i2;
+    const int64_t i03 = i3;
+
+    const float4 * x4 = (const float4 *) x;
 
     block_q8_1_mmq * y = (block_q8_1_mmq *) vy;
 
     const int64_t ib0 = blockIdx.z*((int64_t)gridDim.y*gridDim.x*blockDim.x/QK8_1); // first block of channel
-    const int64_t ib  = ib0 + (ix0 / (4*QK8_1))*kx1 + blockIdx.y;                   // block index in channel
-    const int64_t iqs = ix0 % (4*QK8_1);                                            // quant index in block
+    const int64_t ib  = ib0 + (i0 / (4*QK8_1))*ne1 + blockIdx.y;                    // block index in channel
+    const int64_t iqs = i0 % (4*QK8_1);                                             // quant index in block
 
     // Load 4 floats per thread and calculate max. abs. value between them:
-    const float4 xi = ix0 < kx0 ? x4[(ix1*kx0 + ix0)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
+    const float4 xi = i0 < ne00 ? x4[(i03*s03 + i02*s02 + i01*s01 + i00)/4] : make_float4(0.0f, 0.0f, 0.0f, 0.0f);
     float amax = fabsf(xi.x);
     amax = fmaxf(amax, fabsf(xi.y));
     amax = fmaxf(amax, fabsf(xi.z));
@@ -77,7 +96,7 @@ static __global__ void quantize_mmq_q8_1(
     if (ds_layout != MMQ_Q8_1_DS_LAYOUT_D4) {
         sum = xi.x + xi.y + xi.z + xi.w;
 
-        // Exchange calculate sum across vals_per_sum/4 threads.
+        // Calculate sums across vals_per_sum/4 threads.
 #pragma unroll
         for (int offset = vals_per_sum/8; offset > 0; offset >>= 1) {
             sum += __shfl_xor_sync(0xFFFFFFFF, sum, offset, WARP_SIZE);
@@ -127,40 +146,40 @@ static __global__ void quantize_mmq_q8_1(
 }
 
 void quantize_row_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
-    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+        const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
+        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
+    GGML_ASSERT(!ids);
 
-    GGML_ASSERT(kx0_padded % QK8_1 == 0);
+    GGML_ASSERT(ne0 % QK8_1 == 0);
 
-    const int64_t block_num_x = (kx0_padded + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
-    const dim3 num_blocks(block_num_x, kx1*channels, 1);
+    const int64_t block_num_x = (ne0 + CUDA_QUANTIZE_BLOCK_SIZE - 1) / CUDA_QUANTIZE_BLOCK_SIZE;
+    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE, 1, 1);
-    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx0_padded);
+    quantize_q8_1<<<num_blocks, block_size, 0, stream>>>(x, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
 
-    GGML_UNUSED(type_x);
+    GGML_UNUSED(type_src0);
 }
 
 void quantize_mmq_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels,
-    const int64_t kx0_padded, const ggml_type type_x, cudaStream_t stream) {
+        const float * x, const int32_t * ids, void * vy, const ggml_type type_src0,
+        const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
+        const int64_t ne0, const int64_t ne1, const int64_t ne2, const int64_t ne3, cudaStream_t stream) {
 
-    GGML_ASSERT(kx0_padded % (4*QK8_1) == 0);
+    GGML_ASSERT(ne0 % (4*QK8_1) == 0);
 
-    const int64_t block_num_x = (kx0_padded + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
-    const dim3 num_blocks(block_num_x, kx1, channels);
+    const int64_t block_num_x = (ne0 + 4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ - 1) / (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ);
+    const dim3 num_blocks(block_num_x, ne1, ne2*ne3);
     const dim3 block_size(CUDA_QUANTIZE_BLOCK_SIZE_MMQ, 1, 1);
-    switch (mmq_get_q8_1_ds_layout(type_x)) {
+    switch (mmq_get_q8_1_ds_layout(type_src0)) {
         case MMQ_Q8_1_DS_LAYOUT_D4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
             break;
         case MMQ_Q8_1_DS_LAYOUT_DS4:
             quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_DS4>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
             break;
         case MMQ_Q8_1_DS_LAYOUT_D2S6:
            quantize_mmq_q8_1<MMQ_Q8_1_DS_LAYOUT_D2S6>
-                <<<num_blocks, block_size, 0, stream>>>(x, vy, kx0, kx1, kx0_padded);
+                <<<num_blocks, block_size, 0, stream>>>(x, ids, vy, ne00, s01, s02, s03, ne0, ne1, ne2);
            break;
        default:
            GGML_ABORT("fatal error");
...
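A hedged sketch of driving the reworked quantize_row_q8_1_cuda for a contiguous F32 source of logical size ne00 x ne01 x ne02 x ne03; this is not part of the commit. It assumes the strides s01/s02/s03 are given in float elements (matching the x[i03*s03 + i02*s02 + i01*s01 + i00] indexing in the kernel) and that ne0 is the row length padded to a multiple of QK8_1; the function name, sizes, and the GGML_TYPE_Q4_0 choice are illustrative.

#include "quantize.cuh" // declares the launchers; QK8_1 and MATRIX_ROW_PADDING come in via its includes

static void quantize_src_example(const float * x_dev, void * q8_dev, cudaStream_t stream) {
    const int64_t ne00 = 4096, ne01 = 32, ne02 = 4, ne03 = 1;
    const int64_t ne0  = GGML_PAD(ne00, MATRIX_ROW_PADDING); // padded row length, multiple of QK8_1

    // element strides of a contiguous source tensor
    const int64_t s01 = ne00;
    const int64_t s02 = ne00*ne01;
    const int64_t s03 = ne00*ne01*ne02;

    // q8_dev must hold (ne0/QK8_1)*ne01*ne02*ne03 block_q8_1 blocks; ids must be
    // NULL for this entry point (asserted above), and type_src0 is unused by it.
    quantize_row_q8_1_cuda(x_dev, /*ids=*/nullptr, q8_dev, GGML_TYPE_Q4_0,
                           ne00, s01, s02, s03,
                           ne0, ne01, ne02, ne03, stream);
}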
@@ -12,13 +12,16 @@ static_assert(MATRIX_ROW_PADDING % CUDA_QUANTIZE_BLOCK_SIZE == 0, "Risk
 static_assert(MATRIX_ROW_PADDING % (4*CUDA_QUANTIZE_BLOCK_SIZE_MMQ) == 0, "Risk of out-of-bounds access.");
 
 typedef void (*quantize_cuda_t)(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, const int32_t * ids, void * vy,
+    ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+    int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
 
 void quantize_row_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, const int32_t * ids, void * vy,
+    ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+    int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
 
 void quantize_mmq_q8_1_cuda(
-    const float * x, void * vy, const int64_t kx0, const int64_t kx1, const int64_t channels, const int64_t kx0_padded,
-    const ggml_type type_x, cudaStream_t stream);
+    const float * x, const int32_t * ids, void * vy,
+    ggml_type type_src0, int64_t ne00, int64_t s01, int64_t s02, int64_t s03,
+    int64_t ne0, int64_t ne1, int64_t ne2, int64_t ne3, cudaStream_t stream);
#pragma once
#include "common.cuh" #include "common.cuh"
#include <cstdint> #include <cstdint>
......
@@ -5690,7 +5690,7 @@ kernel void kernel_flash_attn_ext(
     {
         float S[Q] = { [0 ... Q-1] = 0.0f };
-        float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
+        float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
 
         // thread indices inside the simdgroup
         // TODO: see if we can utilize quad-group functions for better performance
@@ -5950,7 +5950,7 @@ kernel void kernel_flash_attn_ext(
         // reduce the warps sequentially
         for (ushort sg = 1; sg < nsg; ++sg) {
             float S = { 0.0f };
-            float M = { -__FLT16_MAX__/2 };
+            float M = { -__FLT_MAX__/2 };
 
             threadgroup_barrier(mem_flags::mem_threadgroup);
 
@@ -6197,7 +6197,7 @@ kernel void kernel_flash_attn_ext_vec(
     {
         float S = 0.0f;
-        float M = -__FLT16_MAX__/2;
+        float M = -__FLT_MAX__/2;
 
         // thread indices inside the simdgroup
         const short tx = tiisg%NL;
...
@@ -44,8 +44,8 @@ static struct ggml_backend_device g_ggml_backend_metal_device;
 // note: assumes single GPU device - the default one
 // TODO: support multiple GPU devices
 static struct ggml_backend_metal_device_context {
     id<MTLDevice>  mtl_device;
     int            mtl_device_ref_count;
     id<MTLLibrary> mtl_library;
 
     bool has_simdgroup_reduction;
@@ -491,7 +491,259 @@ enum ggml_metal_kernel_type {
     GGML_METAL_KERNEL_TYPE_COUNT
 };
 
//
// ggml_metal_heap
//
struct ggml_metal_heap {
// number of times the heap was unused
int n_unused;
// total number of buffer allocations in this heap across all computes
int64_t n_alloc;
// current offset in the heap - we reset this after each node in order to reuse the memory
size_t offs;
// the currently allocated MTLBuffer objects in this heap
id<MTLHeap> obj;
NSMutableArray * bufs;
};
static struct ggml_metal_heap * ggml_metal_heap_init(id<MTLDevice> device, size_t size) {
struct ggml_metal_heap * heap = calloc(1, sizeof(struct ggml_metal_heap));
MTLHeapDescriptor * desc = [[MTLHeapDescriptor alloc] init];
desc.storageMode = MTLStorageModePrivate;
desc.cpuCacheMode = MTLCPUCacheModeDefaultCache;
desc.type = MTLHeapTypePlacement;
desc.size = size;
heap->n_unused = 0;
heap->n_alloc = 0;
heap->obj = [device newHeapWithDescriptor:desc];
if (!heap->obj) {
GGML_LOG_ERROR("%s: error: failed to create MTLHeap with size %zu\n", __func__, size);
free(heap);
return false;
}
[desc release];
heap->bufs = [[NSMutableArray alloc] init];
return heap;
}
static void ggml_metal_heap_reset(struct ggml_metal_heap * heap) {
heap->offs = 0;
// count how many graph computes the heap ended up being unused
if ([heap->bufs count] > 0) {
heap->n_unused = 0;
} else {
heap->n_unused++;
}
for (id<MTLBuffer> buf in heap->bufs) {
[buf release];
}
[heap->bufs removeAllObjects];
// tell the OS that it can reuse this memory if needed
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
[heap->obj setPurgeableState:MTLPurgeableStateVolatile];
}
static void ggml_metal_heap_free(struct ggml_metal_heap * heap) {
if (heap == nil) {
return;
}
ggml_metal_heap_reset(heap);
[heap->obj release];
[heap->bufs release];
free(heap);
}
@interface ggml_metal_heap_ptr : NSObject
@property (nonatomic, assign) struct ggml_metal_heap * data;
@end
@implementation ggml_metal_heap_ptr
@end
//
// ggml_metal_mem_pool
//
struct ggml_metal_mem_pool {
id<MTLDevice> device;
int n_heaps; // total number of heaps ever created (including those that were removed)
NSMutableArray * heaps;
NSMutableArray * heaps_to_remove;
};
static struct ggml_metal_mem_pool * ggml_metal_mem_pool_init(void) {
struct ggml_metal_mem_pool * mem_pool = calloc(1, sizeof(struct ggml_metal_mem_pool));
mem_pool->n_heaps = 0;
mem_pool->heaps = [[NSMutableArray alloc] init];
mem_pool->heaps_to_remove = [[NSMutableArray alloc] init];
return mem_pool;
}
static void ggml_metal_mem_pool_free(struct ggml_metal_mem_pool * mem_pool) {
GGML_LOG_DEBUG("%s: freeing memory pool, num heaps = %zu (total = %d)\n", __func__, [mem_pool->heaps count], mem_pool->n_heaps);
size_t size_all = 0;
size_t size_cur = 0;
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
GGML_LOG_DEBUG("%s: heap: %p\n", __func__, (void *) ptr.data);
GGML_LOG_DEBUG("%s: n_alloc: %" PRId64 "\n", __func__, ptr.data->n_alloc);
GGML_LOG_DEBUG("%s: n_unused: %d\n", __func__, ptr.data->n_unused);
GGML_LOG_DEBUG("%s: size: %.2f MiB\n", __func__, [ptr.data->obj size] / 1024.0 / 1024.0);
GGML_LOG_DEBUG("%s: bufs: %zu\n", __func__, [ptr.data->bufs count]);
if ([ptr.data->bufs count] > 0) {
size_cur += [ptr.data->obj size];
}
size_all += [ptr.data->obj size];
ggml_metal_heap_free(ptr.data);
[ptr release];
}
[mem_pool->heaps release];
[mem_pool->heaps_to_remove release];
if (size_all > 0) {
GGML_LOG_DEBUG("%s: size_all: %.2f MiB\n", __func__, size_all / 1024.0 / 1024.0);
GGML_LOG_DEBUG("%s: size_cur: %.2f MiB\n", __func__, size_cur / 1024.0 / 1024.0);
}
free(mem_pool);
}
static void ggml_metal_mem_pool_reset(struct ggml_metal_mem_pool * mem_pool) {
for (NSUInteger i = 0; i < [mem_pool->heaps count]; i++) {
ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:i];
struct ggml_metal_heap * heap = ptr.data;
ggml_metal_heap_reset(heap);
// if the heap hasn't been used for a while, remove it
if (heap->n_unused >= 128) {
[mem_pool->heaps_to_remove addObject:@(i)];
}
}
if (mem_pool->heaps_to_remove.count > 0) {
for (NSUInteger i = 0; i < [mem_pool->heaps_to_remove count]; i++) {
NSUInteger index = [[mem_pool->heaps_to_remove objectAtIndex:i] intValue];
ggml_metal_heap_ptr * ptr = [mem_pool->heaps objectAtIndex:index];
struct ggml_metal_heap * heap = ptr.data;
ggml_metal_heap_free(heap);
[mem_pool->heaps removeObjectAtIndex:index];
[ptr release];
}
[mem_pool->heaps_to_remove removeAllObjects];
}
}
static void ggml_metal_mem_pool_clear(struct ggml_metal_mem_pool * mem_pool) {
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
ptr.data->offs = 0;
}
}
static id<MTLBuffer> ggml_metal_mem_pool_alloc(struct ggml_metal_mem_pool * mem_pool, size_t size) {
const size_t alignment = 32;
const size_t size_aligned = GGML_PAD(size, alignment);
// try one of the existing heaps
for (ggml_metal_heap_ptr * ptr in mem_pool->heaps) {
struct ggml_metal_heap * heap = ptr.data;
if (heap->offs + size_aligned <= [heap->obj size]) {
// if this is the first buffer in the heap for the current command buffer, tell the OS that
// it cannot free the memory used by the heap
// ref: https://developer.apple.com/documentation/metal/mtlpurgeablestate?language=objc
if ([heap->bufs count] == 0) {
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
}
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
if (buf == nil) {
GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
return nil;
}
heap->n_alloc++;
heap->offs += size_aligned;
[heap->bufs addObject:buf];
return buf;
}
}
// create a new heap that can fit this buffer
ggml_metal_heap_ptr * heap_ptr = [ggml_metal_heap_ptr new];
struct ggml_metal_heap * heap = ggml_metal_heap_init(mem_pool->device, size_aligned);
if (heap == NULL) {
GGML_LOG_ERROR("%s: error: failed to create heap of size %zu\n", __func__, size_aligned);
return NULL;
}
//GGML_LOG_DEBUG("%s: creating new heap of size %zu, got %zu\n", __func__, size_aligned, [heap->obj size]);
heap_ptr.data = heap;
ggml_metal_heap_reset(heap);
[heap->obj setPurgeableState:MTLPurgeableStateNonVolatile];
id<MTLBuffer> buf = [heap->obj newBufferWithLength:size_aligned options:MTLResourceStorageModePrivate offset:heap->offs];
if (buf == nil) {
GGML_LOG_ERROR("%s: error: failed to create MTLBuffer with size %zu\n", __func__, size_aligned);
return NULL;
}
heap->n_alloc++;
heap->offs += size_aligned;
[heap->bufs addObject:buf];
[mem_pool->heaps addObject:heap_ptr];
mem_pool->n_heaps++;
return buf;
}
struct ggml_metal_command_buffer {
id<MTLCommandBuffer> obj;
// each command buffer has a memory pool from which it can allocate temporary buffers during the compute
struct ggml_metal_mem_pool * mem_pool;
};
 struct ggml_backend_metal_context {
+    id<MTLDevice>       device;
     id<MTLCommandQueue> queue;
 
     dispatch_queue_t d_queue;
@@ -516,7 +768,7 @@ struct ggml_backend_metal_context {
     void (^encode_async)(size_t ith);
 
     // n_cb command buffers + 1 used by the main thread
-    id<MTLCommandBuffer> command_buffers[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
+    struct ggml_metal_command_buffer cmd_bufs[GGML_METAL_MAX_COMMAND_BUFFERS + 1];
 
     // abort ggml_metal_graph_compute if callback returns true
     ggml_abort_callback abort_callback;
@@ -706,9 +958,11 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     struct ggml_backend_metal_device_context * ctx_dev = dev->context;
 
     id<MTLDevice> device = ggml_backend_metal_device_acq(ctx_dev);
     GGML_LOG_INFO("%s: picking default device: %s\n", __func__, [[device name] UTF8String]);
 
-    ctx->queue = [device newCommandQueue];
+    ctx->device = device;
+    ctx->queue  = [device newCommandQueue];
     if (ctx->queue == nil) {
         GGML_LOG_ERROR("%s: error: failed to create command queue\n", __func__);
         return NULL;
@@ -769,7 +1023,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
     ctx->gf = nil;
     ctx->encode_async = nil;
     for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
-        ctx->command_buffers[i] = nil;
+        ctx->cmd_bufs[i].obj = nil;
+
+        ctx->cmd_bufs[i].mem_pool = ggml_metal_mem_pool_init();
+        ctx->cmd_bufs[i].mem_pool->device = device;
     }
 
 #if TARGET_OS_OSX || (TARGET_OS_IOS && __clang_major__ >= 15)
@@ -1183,6 +1440,12 @@ static void ggml_metal_free(struct ggml_backend_metal_context * ctx) {
     [ctx->queue release];
 
+    for (int i = 0; i < GGML_METAL_MAX_COMMAND_BUFFERS; ++i) {
+        // ctx->cmd_bufs[i].obj is auto released
+
+        ggml_metal_mem_pool_free(ctx->cmd_bufs[i].mem_pool);
+    }
+
     dispatch_release(ctx->d_queue);
 
     free(ctx);
@@ -1489,10 +1752,11 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
     }
 }
-static void ggml_metal_encode_node(
+static bool ggml_metal_encode_node(
         ggml_backend_t backend,
         int idx,
-        id<MTLComputeCommandEncoder> encoder) {
+        id<MTLComputeCommandEncoder> encoder,
+        struct ggml_metal_mem_pool * mem_pool) {
     struct ggml_backend_metal_context        * ctx     = backend->context;
     struct ggml_backend_metal_device_context * ctx_dev = backend->device->context;
@@ -1508,7 +1772,7 @@ static void ggml_metal_encode_node(
     struct ggml_tensor * dst = node;
 
     if (ggml_is_empty(dst)) {
-        return;
+        return true;
     }
 
     switch (dst->op) {
@@ -1519,7 +1783,7 @@ static void ggml_metal_encode_node(
         case GGML_OP_PERMUTE:
             {
                 // noop -> next node
-            } return;
+            } return true;
         default:
             {
             } break;
@@ -1530,6 +1794,8 @@ static void ggml_metal_encode_node(
         GGML_ABORT("unsupported op");
     }
 
+    ggml_metal_mem_pool_clear(mem_pool);
+
     const int64_t ne00 = src0 ? src0->ne[0] : 0;
     const int64_t ne01 = src0 ? src0->ne[1] : 0;
     const int64_t ne02 = src0 ? src0->ne[2] : 0;
@@ -2176,26 +2442,76 @@ static void ggml_metal_encode_node(
                 const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
                 const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-                ggml_metal_kargs_soft_max args = {
-                    /*.ne00        =*/ ne00,
-                    /*.ne01        =*/ ne01,
-                    /*.ne02        =*/ ne02,
-                    /*.scale       =*/ scale,
-                    /*.max_bias    =*/ max_bias,
-                    /*.m0          =*/ m0,
-                    /*.m1          =*/ m1,
+                // use this branch to test the ggml_metal_mem_pool functionality
+#if 0
+                // cpy to tmp buffer in MTLHeap
+                id<MTLBuffer> h_src0 = h_src0 = ggml_metal_mem_pool_alloc(mem_pool, ggml_nbytes(src0));
+                if (!h_src0) {
+                    GGML_LOG_ERROR("%s: failed to allocate buffer from memory pool, size = %zu\n", __func__, ggml_nbytes(src0));
+                    return false;
+                }
+
+                offs_src0 = 0;
+
+                ggml_metal_kargs_cpy args_cpy = {
+                    /*.ne00 =*/ ne00,
+                    /*.ne01 =*/ ne01,
+                    /*.ne02 =*/ ne02,
+                    /*.ne03 =*/ ne03,
+                    /*.nb00 =*/ nb00,
+                    /*.nb01 =*/ nb01,
+                    /*.nb02 =*/ nb02,
+                    /*.nb03 =*/ nb03,
+                    /*.ne0  =*/ ne00,
+                    /*.ne1  =*/ ne01,
+                    /*.ne2  =*/ ne02,
+                    /*.ne3  =*/ ne03,
+                    /*.nb0  =*/ nb00,
+                    /*.nb1  =*/ nb01,
+                    /*.nb2  =*/ nb02,
+                    /*.nb3  =*/ nb03,
+                };
+
+                if (src0->type == GGML_TYPE_F16) {
+                    [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F16_F16].pipeline];
+                } else {
+                    [encoder setComputePipelineState:ctx->kernels[GGML_METAL_KERNEL_TYPE_CPY_F32_F32].pipeline];
+                }
+
+                [encoder setBytes:&args_cpy length:sizeof(args_cpy) atIndex:0];
+                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                [encoder setBuffer:h_src0 offset:0 atIndex:2];
+
+                GGML_ASSERT(ne00 % ggml_blck_size(src0->type) == 0);
+                int nth_cpy = MIN(1024, ne00 / ggml_blck_size(src0->type));
+
+                [encoder dispatchThreadgroups:MTLSizeMake(ne01, ne02, ne03) threadsPerThreadgroup:MTLSizeMake(nth_cpy, 1, 1)];
+
+#else
+                id<MTLBuffer> h_src0 = id_src0;
+#endif
+                // softmax
+
+                ggml_metal_kargs_soft_max args = {
+                    /*.ne00        =*/ ne00,
+                    /*.ne01        =*/ ne01,
+                    /*.ne02        =*/ ne02,
+                    /*.scale       =*/ scale,
+                    /*.max_bias    =*/ max_bias,
+                    /*.m0          =*/ m0,
+                    /*.m1          =*/ m1,
                     /*.n_head_log2 =*/ n_head_log2,
                 };
 
                 [encoder setComputePipelineState:pipeline];
-                [encoder setBuffer:id_src0 offset:offs_src0 atIndex:0];
+                [encoder setBuffer:h_src0 offset:offs_src0 atIndex:0];
                 if (id_src1) {
                     [encoder setBuffer:id_src1 offset:offs_src1 atIndex:1];
                 } else {
-                    [encoder setBuffer:id_src0 offset:offs_src0 atIndex:1];
+                    [encoder setBuffer:h_src0 offset:offs_src0 atIndex:1];
                 }
                 [encoder setBuffer:id_dst offset:offs_dst atIndex:2];
                 [encoder setBytes:&args length:sizeof(args) atIndex:3];
 
                 [encoder setThreadgroupMemoryLength:32*sizeof(float) atIndex:0];
@@ -4634,6 +4950,8 @@ static void ggml_metal_encode_node(
                 GGML_ABORT("fatal error");
             }
     }
+
+    return true;
 }
 
 static enum ggml_status ggml_metal_graph_compute(
@@ -4687,25 +5005,25 @@ static enum ggml_status ggml_metal_graph_compute(
     }
 
     // the main thread commits the first few commands immediately
-    // command_buffer[n_cb]
+    // cmd_buf[n_cb]
     {
-        id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
-        ctx->command_buffers[n_cb] = command_buffer;
+        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+        ctx->cmd_bufs[n_cb].obj = cmd_buf;
 
-        [command_buffer enqueue];
+        [cmd_buf enqueue];
         ctx->encode_async(n_cb);
     }
 
     // prepare the rest of the command buffers asynchronously
-    // command_buffer[0.. n_cb)
+    // cmd_buf[0.. n_cb)
     for (int cb_idx = 0; cb_idx < n_cb; ++cb_idx) {
-        id<MTLCommandBuffer> command_buffer = [ctx->queue commandBufferWithUnretainedReferences];
-        ctx->command_buffers[cb_idx] = command_buffer;
+        id<MTLCommandBuffer> cmd_buf = [ctx->queue commandBufferWithUnretainedReferences];
+        ctx->cmd_bufs[cb_idx].obj = cmd_buf;
 
         // always enqueue the first two command buffers
         // enqueue all of the command buffers if we don't need to abort
         if (cb_idx < 2 || ctx->abort_callback == NULL) {
-            [command_buffer enqueue];
+            [cmd_buf enqueue];
         }
     }
@@ -4714,14 +5032,14 @@ static enum ggml_status ggml_metal_graph_compute(
     // wait for completion and check status of each command buffer
     // needed to detect if the device ran out-of-memory for example (#1881)
     {
-        id<MTLCommandBuffer> command_buffer = ctx->command_buffers[n_cb];
-        [command_buffer waitUntilCompleted];
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[n_cb].obj;
+        [cmd_buf waitUntilCompleted];
 
-        MTLCommandBufferStatus status = [command_buffer status];
+        MTLCommandBufferStatus status = [cmd_buf status];
         if (status != MTLCommandBufferStatusCompleted) {
             GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, n_cb, status);
             if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
             }
 
             return GGML_STATUS_FAILED;
@@ -4729,20 +5047,20 @@ static enum ggml_status ggml_metal_graph_compute(
     }
 
     for (int i = 0; i < n_cb; ++i) {
-        id<MTLCommandBuffer> command_buffer = ctx->command_buffers[i];
-        [command_buffer waitUntilCompleted];
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[i].obj;
+        [cmd_buf waitUntilCompleted];
 
-        MTLCommandBufferStatus status = [command_buffer status];
+        MTLCommandBufferStatus status = [cmd_buf status];
         if (status != MTLCommandBufferStatusCompleted) {
             GGML_LOG_INFO("%s: command buffer %d failed with status %lu\n", __func__, i, status);
             if (status == MTLCommandBufferStatusError) {
-                GGML_LOG_INFO("error: %s\n", [[command_buffer error].localizedDescription UTF8String]);
+                GGML_LOG_INFO("error: %s\n", [[cmd_buf error].localizedDescription UTF8String]);
             }
 
             return GGML_STATUS_FAILED;
         }
 
-        id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->command_buffers[i + 1] : nil);
+        id<MTLCommandBuffer> next_buffer = (i + 1 < n_cb ? ctx->cmd_bufs[i + 1].obj : nil);
         if (!next_buffer) {
             continue;
         }
@@ -5126,8 +5444,9 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         const int n_nodes_per_cb = ctx->n_nodes_per_cb;
 
-        id<MTLCommandBuffer> command_buffer = ctx->command_buffers[cb_idx];
-        id<MTLComputeCommandEncoder> encoder = [command_buffer computeCommandEncoder];
+        id<MTLCommandBuffer> cmd_buf = ctx->cmd_bufs[cb_idx].obj;
+
+        id<MTLComputeCommandEncoder> encoder = [cmd_buf computeCommandEncoder];
 
         int node_start = 0;
         int node_end   = n_nodes_0;
@@ -5139,22 +5458,29 @@ static void ggml_backend_metal_set_n_cb(ggml_backend_t backend, int n_cb) {
         const bool should_capture = ctx->capture_next_compute;
 
+        struct ggml_metal_mem_pool * mem_pool = ctx->cmd_bufs[cb_idx].mem_pool;
+        ggml_metal_mem_pool_reset(mem_pool);
+
         for (int idx = node_start; idx < node_end; ++idx) {
             if (should_capture) {
                 [encoder pushDebugGroup:[NSString stringWithCString:ggml_op_desc(ggml_graph_node(ctx->gf, idx)) encoding:NSUTF8StringEncoding]];
             }
 
-            ggml_metal_encode_node(backend, idx, encoder);
+            const bool res = ggml_metal_encode_node(backend, idx, encoder, mem_pool);
 
             if (should_capture) {
                 [encoder popDebugGroup];
             }
+
+            if (!res) {
+                break;
+            }
         }
 
         [encoder endEncoding];
 
         if (cb_idx < 2 || ctx->abort_callback == NULL) {
-            [command_buffer commit];
+            [cmd_buf commit];
         }
     });
 }
...
@@ -3237,7 +3237,7 @@ kernel void kernel_flash_attn_ext(
     {
         float S[Q] = { [0 ... Q-1] = 0.0f };
-        float M[Q] = { [0 ... Q-1] = -__FLT16_MAX__/2 };
+        float M[Q] = { [0 ... Q-1] = -__FLT_MAX__/2 };
 
         // thread indices inside the simdgroup
         // TODO: see if we can utilize quad-group functions for better performance
@@ -3497,7 +3497,7 @@ kernel void kernel_flash_attn_ext(
         // reduce the warps sequentially
         for (ushort sg = 1; sg < nsg; ++sg) {
             float S = { 0.0f };
-            float M = { -__FLT16_MAX__/2 };
+            float M = { -__FLT_MAX__/2 };
 
             threadgroup_barrier(mem_flags::mem_threadgroup);
 
@@ -3744,7 +3744,7 @@ kernel void kernel_flash_attn_ext_vec(
     {
         float S = 0.0f;
-        float M = -__FLT16_MAX__/2;
+        float M = -__FLT_MAX__/2;
 
         // thread indices inside the simdgroup
         const short tx = tiisg%NL;
...
@@ -4,6 +4,7 @@
 #include "ggml-backend.h"
 #include "ggml-impl.h"
 #include "ggml-threading.h"
+#include "ggml-cpu.h"
 #include "ggml.h"
 
 // FIXME: required here for quantization functions
@@ -382,58 +383,16 @@ void ggml_fp16_to_fp32_row(const ggml_fp16_t * x, float * y, int64_t n) {
     }
 }
 
-// FIXME: these functions must detect the instruction set at runtime, since they are part of the core ggml library
-//        currently, the ggml_cpu_has_* functions are entirely compile-time
 void ggml_fp32_to_fp16_row(const float * x, ggml_fp16_t * y, int64_t n) {
-    int64_t i = 0;
-#if defined(__F16C__)
-    //if (ggml_cpu_has_f16c()) {
-        for (; i + 7 < n; i += 8) {
-            __m256 x_vec = _mm256_loadu_ps(x + i);
-            __m128i y_vec = _mm256_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-            _mm_storeu_si128((__m128i *)(y + i), y_vec);
-        }
-        for(; i + 3 < n; i += 4) {
-            __m128 x_vec = _mm_loadu_ps(x + i);
-            __m128i y_vec = _mm_cvtps_ph(x_vec, _MM_FROUND_TO_NEAREST_INT);
-            _mm_storel_epi64((__m128i *)(y + i), y_vec);
-        }
-    //}
-#endif
-    for (; i < n; i++) {
+    int i = 0;
+    for (; i < n; ++i) {
         y[i] = GGML_FP32_TO_FP16(x[i]);
     }
 }
 
 void ggml_bf16_to_fp32_row(const ggml_bf16_t * x, float * y, int64_t n) {
-    int64_t i = 0;
-#if defined(__AVX512F__)
-    //if (ggml_cpu_has_avx512()) {
-        for (; i + 16 <= n; i += 16) {
-            _mm512_storeu_ps(y + i,
-                            _mm512_castsi512_ps(
-                                _mm512_slli_epi32(
-                                    _mm512_cvtepu16_epi32(
-                                        _mm256_loadu_si256(
-                                            (const __m256i *)(x + i))),
-                                    16)));
-        }
-    //}
-#endif
-#if defined(__AVX2__)
-    //if (ggml_cpu_has_avx2()) {
-        for (; i + 8 <= n; i += 8) {
-            _mm256_storeu_ps(y + i,
-                            _mm256_castsi256_ps(
-                                _mm256_slli_epi32(
-                                    _mm256_cvtepu16_epi32(
-                                        _mm_loadu_si128(
-                                            (const __m128i *)(x + i))),
-                                    16)));
-        }
-    //}
-#endif
-    for (; i < n; i++) {
+    int i = 0;
+    for (; i < n; ++i) {
         y[i] = GGML_BF16_TO_FP32(x[i]);
     }
 }
@@ -956,6 +915,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "CONV_TRANSPOSE_1D",
     "IM2COL",
     "IM2COL_BACK",
+    "CONV_2D_DW",
     "CONV_TRANSPOSE_2D",
     "POOL_1D",
     "POOL_2D",
@@ -994,7 +954,7 @@ static const char * GGML_OP_NAME[GGML_OP_COUNT] = {
     "OPT_STEP_ADAMW",
 };
 
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
 
 static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "none",
@@ -1051,6 +1011,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "conv_transpose_1d(x)",
     "im2col(x)",
     "im2col_back(x)",
+    "conv_2d_dw(x)",
    "conv_transpose_2d(x)",
    "pool_1d(x)",
    "pool_2d(x)",
@@ -1089,7 +1050,7 @@ static const char * GGML_OP_SYMBOL[GGML_OP_COUNT] = {
     "adamw(x)",
 };
 
-static_assert(GGML_OP_COUNT == 82, "GGML_OP_COUNT != 82");
+static_assert(GGML_OP_COUNT == 83, "GGML_OP_COUNT != 83");
 
 static_assert(GGML_OP_POOL_COUNT == 2, "GGML_OP_POOL_COUNT != 2");
 
@@ -1346,6 +1307,13 @@ bool ggml_is_permuted(const struct ggml_tensor * tensor) {
     return tensor->nb[0] > tensor->nb[1] || tensor->nb[1] > tensor->nb[2] || tensor->nb[2] > tensor->nb[3];
 }
 
bool ggml_is_contiguous_channels(const struct ggml_tensor * tensor) {
return
tensor->nb[0] > tensor->nb[2] &&
tensor->nb[1] > tensor->nb[0] &&
tensor->nb[2] == ggml_type_size(tensor->type);
}
 static inline bool ggml_is_padded_1d(const struct ggml_tensor * tensor) {
     static_assert(GGML_MAX_DIMS == 4, "GGML_MAX_DIMS is not 4 - update this function");
@@ -4052,6 +4020,46 @@ struct ggml_tensor * ggml_conv_2d_dw(
     return result;
 }
 
// ggml_conv_2d_dw_direct
struct ggml_tensor * ggml_conv_2d_dw_direct(
struct ggml_context * ctx,
struct ggml_tensor * a,
struct ggml_tensor * b,
int stride0,
int stride1,
int pad0,
int pad1,
int dilation0,
int dilation1) {
GGML_ASSERT(a->ne[2] == 1);
GGML_ASSERT(a->ne[3] == b->ne[2]);
int64_t ne[4];
ne[0] = ggml_calc_conv_output_size(b->ne[0], a->ne[0], stride0, pad0, dilation0);
ne[1] = ggml_calc_conv_output_size(b->ne[1], a->ne[1], stride1, pad1, dilation1);
ne[2] = b->ne[2];
ne[3] = b->ne[3];
struct ggml_tensor * result = ggml_new_tensor(ctx, b->type, 4, ne);
if (ggml_is_contiguous_channels(b)) {
// Result will be permuted the same way as input (CWHN order)
const int64_t type_size = ggml_type_size(result->type);
GGML_ASSERT(ggml_blck_size(result->type) == 1);
result->nb[0] = result->ne[2] * type_size;
result->nb[1] = result->ne[0] * result->nb[0];
result->nb[2] = type_size;
}
int32_t params[] = { stride0, stride1, pad0, pad1, dilation0, dilation1 };
ggml_set_op_params(result, params, sizeof(params));
result->op = GGML_OP_CONV_2D_DW;
result->src[0] = a;
result->src[1] = b;
return result;
}
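A hedged usage sketch for the new ggml_conv_2d_dw_direct; not part of the commit, and the function name, sizes, and type choices below are illustrative. It follows the asserts above (kernel a is [KW, KH, 1, C], input b is [W, H, C, N]) and uses ggml_permute to present a contiguous [C, W, H, N] tensor in CWHN order, which makes ggml_is_contiguous_channels(b) true so the result keeps the same channel-contiguous layout.

#include "ggml.h"

static struct ggml_tensor * build_dw_conv(struct ggml_context * ctx, int W, int H, int C, int N, int K) {
    struct ggml_tensor * kernel = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, K, K, 1, C); // [KW, KH, 1, C]
    struct ggml_tensor * input  = ggml_new_tensor_4d(ctx, GGML_TYPE_F32, C, W, H, N); // contiguous, channels fastest
    input = ggml_permute(ctx, input, 2, 0, 1, 3);                                     // view as [W, H, C, N] (CWHN memory order)

    // stride 1, no padding, dilation 1
    return ggml_conv_2d_dw_direct(ctx, kernel, input, 1, 1, 0, 0, 1, 1);
}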
 // ggml_conv_transpose_2d_p0
 
 static int64_t ggml_calc_conv_transpose_output_size(int64_t ins, int64_t ks, int s, int p) {
...