Unverified Commit 2aba569a authored by Thomas Stocker's avatar Thomas Stocker Committed by GitHub
Browse files

Vulkan based on #9650 (#11835)

* implement the vulkan C backend

* add support in gpu.go

* add support in gen_linux.sh

* it builds

* fix segfault

* fix compilation

* fix free memory monitor

* fix total memory monitor

* update gpu.go

* fix build

* fix check_perfmon len

* remove cap_get_bound check

* fix vulkan handle releasing

* fix build on federa 40

* fix vulkan on windows

* making amdgpu work on arm achitecutre with vulkan

* add x86_64 lines in VulkanGlobs and capLinuxGlobs

* add aarch64 lines in vulkanGlobs and capLinuxGlobs

* Fix variable name

* Add vulkan build patch from @jmorganca

* Sync vendored ggml to add Vulkan support

* Updated dockerfile

https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871

Signed-off-by: default avatarVadim Grinco <vadim@grinco.eu>

* Installing rocm library
Signed-off-by: default avatarVadim Grinco <vadim@grinco.eu>

* This version works well

built based on this: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871

Signed-off-by: default avatarVadim Grinco <vadim@grinco.eu>

* Applied 00-fix-vulkan-building.patch

Work done by McBane87 here: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871

Signed-off-by: default avatarVadim Grinco <vadim@grinco.eu>

* Fixed the "detached head" issues
Signed-off-by: default avatarVadim Grinco <vadim@grinco.eu>

* Merged in the right direction
Signed-off-by: default avatarVadim Grinco <vadim@grinco.eu>

* Merging the latest stable (#2)

* Applied 00-fix-vulkan-building.patch

* Implemented vulkan backend based on the work done by whyvl, Dts0, McBane87 and others

Tested on AMD Ryzen 7 8845HS w/ Radeon 780M Graphics with ROCm disabled

```
[GIN-debug] POST   /v1/chat/completions      --> github.com/ollama/ollama/server.(*Server).ChatHandler-fm (6 handlers)
[GIN-debug] POST   /v1/completions           --> github.com/ollama/ollama/server.(*Server).GenerateHandler-fm (6 handlers)
[GIN-debug] POST   /v1/embeddings            --> github.com/ollama/ollama/server.(*Server).EmbedHandler-fm (6 handlers)
[GIN-debug] GET    /v1/models                --> github.com/ollama/ollama/server.(*Server).ListHandler-fm (6 handlers)
[GIN-debug] GET    /v1/models/:model         --> github.com/ollama/ollama/server.(*Server).ShowHandler-fm (6 handlers)
time=2025-03-11T13:00:40.793Z level=INFO source=gpu.go:199 msg="vulkan: load libvulkan and libcap ok"
time=2025-03-11T13:00:40.877Z level=INFO source=gpu.go:421 msg="error looking up vulkan GPU memory" error="device is a CPU"
time=2025-03-11T13:00:40.878Z level=WARN source=amd_linux.go:443 msg="amdgpu detected, but no compatible rocm library found.  Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install"
time=2025-03-11T13:00:40.878Z level=WARN source=amd_linux.go:348 msg="unable to verify rocm library: no suitable rocm found, falling back to CPU"
time=2025-03-11T13:00:40.879Z level=INFO source=types.go:137 msg="inference compute" id=0 library=vulkan variant="" compute=1.3 driver=1.3 name="AMD Radeon Graphics (RADV GFX1103_R1)" total="15.6 GiB" available="15.6 GiB"
```

```
 # ollama run phi4:14b
>>> /set verbose
Set 'verbose' mode.
>>> how's it going?
Hello! I'm here to help you with any questions or tasks you have. How can I assist you today? 😊



total duration:       3.341959745s
load duration:        18.165612ms
prompt eval count:    15 token(s)
prompt eval duration: 475ms
prompt eval rate:     31.58 tokens/s
eval count:           26 token(s)
eval duration:        2.846s
eval rate:            9.14 tokens/s
>>>
```

* This is no longer needed
Signed-off-by: default avatarVadim Grinco <vadim@grinco.eu>

* Fixes SIGSEGV: segmentation violation running gemma3 models on ollama 0.6.0 #21

Patch provided by McBane87 on https://github.com/whyvl/ollama-vulkan/issues/21

Signed-off-by: default avatarVadim Grinco <vadim@grinco.eu>

* Applied 04-disable-mmap-vulkan.patch

From: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871

Signed-off-by: default avatarVadim Grinco <vadim@grinco.eu>

* Pulled new upstream code for ggml-bulkan backend
Signed-off-by: default avatarVadim Grinco <vadim@grinco.eu>

* Merged latest ollama 0.6.2 and nasrally's Flash Attention patches (#5)

* readme: add Ellama to list of community integrations (#9800)

* readme: add screenpipe to community integrations (#9786)

* Add support for ROCm gfx1151 (#9773)

* conditionally enable parallel pipelines

* sample: make mutations in transforms explicit (#9743)

* updated minP to use early exit making use of sorted tokens

* ml/backend/ggml: allocate memory with malloc when loading model (#9822)

* runner: remove cache prompt flag from ollama runner (#9826)

We do not need to bypass the prompt caching in the ollama runner yet, as
only embedding models needed to bypass the prompt caching. When embedding
models are implemented they can skip initializing this cache completely.

* ollamarunner: Check for minBatch of context space when shifting

Models can specify that a group of inputs need to be handled a single
batch. However, context shifting didn't respect this and could trigger
a break anyways. In this case, we should instead trigger a context
shift earlier so that it occurs before the grouped batch.

Note that there still some corner cases:
 - A long prompt that exceeds the context window can get truncated
   in the middle of an image. With the current models, this will
   result in the model not recognizing the image at all, which is
   pretty much the expected result with truncation.
 - The context window is set less than the minimum batch size. The
   only solution to this is to refuse to load the model with these
   settings. However, this can never occur with current models and
   default settings.

Since users are unlikely to run into these scenarios, fixing them is
left as a follow up.

* Applied latest patches from McBane87

See this for details: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2708820861

Signed-off-by: default avatarVadim Grinco <vadim@grinco.eu>

* Add ability to enable flash attention on vulkan (#4

)

* discover: add flash attention handling for vulkan
* envconfig: fix typo in config.go

As part of the process some code was refactored and I added a new field
FlashAttention to GpuInfo since the previous solution didn't allow for a
granular check via vulkan extensions. As a side effect, this now allows
for granular per-device FA support checking in other places

---------
Signed-off-by: default avatarVadim Grinco <vadim@grinco.eu>
Co-authored-by: default avatarzeo <108888572+zeozeozeo@users.noreply.github.com>
Co-authored-by: default avatarLouis Beaumont <louis.beaumont@gmail.com>
Co-authored-by: default avatarDaniel Hiltgen <dhiltgen@users.noreply.github.com>
Co-authored-by: default avatarMichael Yang <mxyng@pm.me>
Co-authored-by: default avatarParth Sareen <parth.sareen@ollama.com>
Co-authored-by: default avatarJeffrey Morgan <jmorganca@gmail.com>
Co-authored-by: default avatarBruce MacDonald <brucewmacdonald@gmail.com>
Co-authored-by: default avatarJesse Gross <jesse@ollama.com>
Co-authored-by: default avatarNikita <50599445+nasrally@users.noreply.github.com>

* Revert Readme changes

* Revert

* Revert changes in amd_linux.go

* Revert changes in amd_linux.go

* Remove flashattention setting gpu.go

* Revert whitespace changes in gpu.go

* Revert changes in transforms_test.go

* Revert changes in runner.go

* Revert changes in Makefile.sync

* Revert some unintented changes in Dockerfile

* Revert vulkan copy changes in Dockerfile

* Update Vulkan Code to de4c07f93783a1a96456a44dc16b9db538ee1618

* Fixed duplicate sync in ggml.go

* Revert changes in ggml.go

* Revert chnages in ggml.go

* enable falsh attention on vulkan

* revert remove parenthesis

* fixed flash attention logic enabling

* vk_check_flash_attention 0 means supported

* Update gpu.go

* Add vulkan to Windows Build script

* Remove commented out code

* Enable Vulkan Flash attention in FlashAttentionSupported

* Fix logging

* Update Vulkan backend to e54d41befcc1575f4c898c5ff4ef43970cead75f

* Removed libcap related code

libcap is not directly related to Vulkan and should be added by its own PR. It adds additional library dependencies for building and also requires users to run setcap or run ollama as root, which is not ideal for easy use

* Fix Unit Test (Add Vulkan Library)

* Add vulkan to TestHomogeneousGPUs
Test

* vulkan: get GPU ID (ollama v0.11.5)
Signed-off-by: default avatarXiaodong Ye <xiaodong.ye@mthreads.com>

* disable mmap for vulkan

* Reduce Changes remove TestHomogeneousGPUs (doesn't exist on master)

* Update vulkan version to the version used in llama.cpp

* rename gpu patch to correct number

* added Vulkan API to get correct Device UUID

current UUID from pipelineCacheUUID does not match CUDA

* Fix GPU ID Patch

* Remove Code not in llama.cpp

* modified UUID code inside ggml

* Fix Patch

* Copied minimal definition from vulkan header

* Fix compile error in Mac

Metal is preferred so we're disabling Vulkan for now

* Removed unused code

Fix linter error in CI

* Fix patches apply

* fixing lint error

* Removed unneeded function call

Somehow removing this call fixed the crashing when Vulkan header was removed

* added missing NL

* Fixed missing members in Vulkan header

also added zero clear for some structs

* Fixed wrong structure ID

* Fixed Vulkan header

More aligned with official header definition now

* buildvulkanAsSeperateFunction

* Vulkan on Windows Test

* temporarly comment out gate to run windows task

* use temporarly windows-latest for build

* Commenting out other presets to build vulkan

* reenable cpu

* commenting out error action stop

* temporarly commenting out rocm

* set vulkan path

* comment out cude for faster turnaround

* correct vulkan install

* correct vulkan silent install

* fixed install command

* revert debugging changes (vulkan builds on windows)

* revert windows-latest

* trying to build vulkan for linux

* temporarly disable cuda and rocm

* try again linux build

* fix version

* trying to fix

* trying again

* trying again

* fix version

* fixed vulkan-sdk name

* try again

* trying again

* try without version number

* try again

* add some more extra

* trying to use version 1.4.313

* revert debugging changes

* Filter out already supported gpus

* revert debug code

* Use runners for GPU discovery

This revamps how we discover GPUs in the system by leveraging the Ollama
runner.  This should eliminate inconsistency between our GPU discovery and the
runners capabilities at runtime, particularly for cases where we try to filter
out unsupported GPUs.  Now the runner does that implicitly based on the actual
device list.  In some cases free VRAM reporting can be unreliable which can
leaad to scheduling mistakes, so this also includes a patch to leverage more
reliable VRAM reporting libraries if available.

Automatic workarounds have been removed as only one GPU leveraged this, which
is now documented. This GPU will soon fall off the support matrix with the next
ROCm bump.

Additional cleanup of the scheduler and discovery packages can be done in the
future once we have switched on the new memory management code, and removed
support for the llama runner.

* timing info for runner

* WIP - wire up Vulkan with the new engine based discovery

Not a complete implementation - free VRAM is better, but not accurate on
windows

* fix - trust the library paths from discovery when starting runner

* fix index bug

* fix vulkan ids to be underlying

* fix - give bootstrapping more time on slow systems

* Test if Vulkan device is supported

* vk_check_flash_attention is not needed (coompat2 coopmapt and scalar implementation exist)

* Handle GGML_VK_VISIBLE_DEVICES

* ask for supported first

* win: fix CPU query buffer handling

Try in a short loop until we get the size right.

* test: harden integration tests for slow start

If the server takes a while to start up, block
tests from starting until it's online to avoid
setting large timeouts in individual test cases.

* gofumpt fix

* fix build

* merge fixes

* merge fixes

* fixed build

* merge fixes

* fixing build

* fixed build

* fixed formatting

* fixed build

* fix vulkan gpu id patch

* sync llama.cpp vulkan code

* update build windows script

* merge fixes

* fix format

* fixed vulkan casing

* handle igpu as gpu

* improve case

* print out unknown library

* rturn Vulkan for vulkan library

* Revert "rturn Vulkan for vulkan library"

This reverts commit 690461a12fd5e93295d174c97edefb2bc33285b1.

* fixed patch number

* return Library Name

* remvoe debug code

* return integrated in vulkan backend

* Return pci Properties

* update patch

* directly get pci proeprties without parsing

* workaround for filtering devices. Correct way is to have a LibraryPosition Parameter in the deviceInfo

* Revert "directly get pci proeprties without parsing"

This reverts commit 8e0624851f5ed7d9f74518f574dfb422e4dd4dc2.

* Set FilteredID for Environment Filtering

* ROCm Library is named ROCm

* revert changes in patch

* Create 0028-vulkan-pci-and-memory.patch

* vulkan memory patch

* casing fix

* Add more pci properties

* Added better memory management

* Added better memory managament

* fixed patch

* Fixed patch

* FilterID creation group by library

* filter out vulkan supported by other gpu

* fixing deviceid compare

* Vulkan Fix FA coopmat1 invalid array indexing

* Use everywhere the same Vulkan Version 1.4.321.1

* Remove unneeded patch

* vulkan update

* sync vulkan glsl files

* only use for vulkan the filteredid (numeric device number)

* simplify code

---------
Signed-off-by: default avatarVadim Grinco <vadim@grinco.eu>
Signed-off-by: default avatarXiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: default avatarpufferffish <github@bandersnatch.anonaddy.com>
Co-authored-by: KOISHI KOMEIJI FROM TOUHOU 11 <fuck>
Co-authored-by: default avatarDSLstandard <qgeneral35@gmail.com>
Co-authored-by: default avatarpufferffish <me@windtfw.com>
Co-authored-by: default avataryeongbba <yeongmo.lee@logpresso.com>
Co-authored-by: default avatartomaThomas <tomathomas@mailbox.org>
Co-authored-by: default avatarAntoine Viallon <antoine@lesviallon.fr>
Co-authored-by: default avatarVadim Grinco <vadim@grinco.eu>
Co-authored-by: default avatarzeo <108888572+zeozeozeo@users.noreply.github.com>
Co-authored-by: default avatarLouis Beaumont <louis.beaumont@gmail.com>
Co-authored-by: default avatarDaniel Hiltgen <dhiltgen@users.noreply.github.com>
Co-authored-by: default avatarMichael Yang <mxyng@pm.me>
Co-authored-by: default avatarParth Sareen <parth.sareen@ollama.com>
Co-authored-by: default avatarJeffrey Morgan <jmorganca@gmail.com>
Co-authored-by: default avatarBruce MacDonald <brucewmacdonald@gmail.com>
Co-authored-by: default avatarJesse Gross <jesse@ollama.com>
Co-authored-by: default avatarNikita <50599445+nasrally@users.noreply.github.com>
Co-authored-by: default avatarMasato Nakasaka <masato.nakasaka@intel.com>
Co-authored-by: default avatarXiaodong Ye <xiaodong.ye@mthreads.com>
Co-authored-by: default avatarDaniel Hiltgen <daniel@ollama.com>
parent fd8aa947
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float x = float(data_a[i]);
data_d[i] = D_TYPE(x * min(1.0f, max(0.0f, (x + 3.0f) / 6.0f)));
}
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#include "rte.glsl"
#include "types.glsl"
layout (push_constant) uniform parameter
{
BDA_STORAGE_T dst_addr;
uint batch_offset; uint offset_delta;
uint IC;
uint IW; uint IH;
uint OW; uint OH;
uint KW; uint KH;
uint pelements;
uint CHW;
int s0; int s1;
int p0; int p1;
int d0; int d1;
} p;
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
const uint NUM_ITER = 512 / BLOCK_SIZE;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
#if BDA
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
void main() {
const uint gidx = gl_GlobalInvocationID.x;
const uint oh = gl_GlobalInvocationID.y;
const uint batch = gl_GlobalInvocationID.z / p.IC;
const uint ic = gl_GlobalInvocationID.z % p.IC;
const uint src_base = ic * p.offset_delta + batch * p.batch_offset;
const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH);
const int oh_s1 = int(oh) * p.s1;
const uint ksize = p.OW * p.KH;
const uint base_linear_idx = gidx * NUM_ITER;
uint current_kx = base_linear_idx / ksize;
const uint rem = base_linear_idx - (current_kx * ksize);
uint current_ky = rem / p.OW;
uint current_ix = rem % p.OW;
A_TYPE values[NUM_ITER];
BDA_OFFSET_T offset_dst[NUM_ITER];
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
values[idx] = A_TYPE(0);
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint linear_idx = base_linear_idx + idx;
if (linear_idx >= p.pelements) {
continue;
}
const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0;
const uint iih = oh_s1 + current_ky * p.d1 - p.p1;
offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx;
if ((iih < p.IH) && (iiw < p.IW)) {
values[idx] = data_a[src_base + iih * p.IW + iiw];
}
if (++current_ix == p.OW) {
current_ix = 0;
if (++current_ky == p.KH) {
current_ky = 0;
current_kx++;
}
}
}
[[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) {
const uint linear_idx = base_linear_idx + idx;
if (linear_idx >= p.pelements) {
continue;
}
#if BDA
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]);
dst_addr.d = D_TYPE(values[idx]);
#else
data_d[offset_dst[idx]] = D_TYPE(values[idx]);
#endif
}
}
#version 450
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_control_flow_attributes : require
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "rte.glsl"
#include "types.glsl"
layout (push_constant) uniform parameter
{
BDA_STORAGE_T dst_addr;
uint32_t nb10;
uint32_t nb11;
uint32_t nb12;
uint32_t nb13;
uint32_t s0;
uint32_t s1;
uint32_t s2;
uint32_t p0;
uint32_t p1;
uint32_t p2;
uint32_t d0;
uint32_t d1;
uint32_t d2;
uint32_t IW;
uint32_t IH;
uint32_t ID;
uint32_t IC;
uint32_t KW;
uint32_t OH;
uint32_t KD_KH_KW;
uint32_t KH_KW;
uint32_t IC_KD_KH_KW;
uint32_t N_OD_OH;
uint32_t OD_OH;
uint32_t OD_OH_OW_IC_KD_KH_KW;
uint32_t OH_OW_IC_KD_KH_KW;
uint32_t OW_IC_KD_KH_KW;
uint32_t misalign_offsets;
} p;
uint get_aoffset() { return p.misalign_offsets >> 16; }
uint get_doffset() { return p.misalign_offsets & 0xFFFF; }
layout(constant_id = 0) const uint BLOCK_SIZE = 32;
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
#if BDA
layout (buffer_reference) buffer D_ptr {D_TYPE d;};
#endif
void main() {
const uint32_t i = gl_GlobalInvocationID.x;
uint32_t nb10 = p.nb10;
uint32_t nb11 = p.nb11;
uint32_t nb12 = p.nb12;
uint32_t nb13 = p.nb13;
uint32_t s0 = p.s0;
uint32_t s1 = p.s1;
uint32_t s2 = p.s2;
uint32_t p0 = p.p0;
uint32_t p1 = p.p1;
uint32_t p2 = p.p2;
uint32_t d0 = p.d0;
uint32_t d1 = p.d1;
uint32_t d2 = p.d2;
uint32_t IW = p.IW;
uint32_t IH = p.IH;
uint32_t ID = p.ID;
uint32_t IC = p.IC;
uint32_t KW = p.KW;
uint32_t OH = p.OH;
uint32_t KD_KH_KW = p.KD_KH_KW;
uint32_t KH_KW = p.KH_KW;
uint32_t IC_KD_KH_KW = p.IC_KD_KH_KW;
uint32_t N_OD_OH = p.N_OD_OH;
uint32_t OD_OH = p.OD_OH;
uint32_t OD_OH_OW_IC_KD_KH_KW = p.OD_OH_OW_IC_KD_KH_KW;
uint32_t OH_OW_IC_KD_KH_KW = p.OH_OW_IC_KD_KH_KW;
uint32_t OW_IC_KD_KH_KW = p.OW_IC_KD_KH_KW;
if (i >= IC_KD_KH_KW) {
return;
}
const uint32_t iic = i / KD_KH_KW;
const uint32_t ikd = (i - iic * KD_KH_KW) / KH_KW;
const uint32_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW;
const uint32_t ikw = i % KW;
const uint32_t iow = gl_GlobalInvocationID.y;
for (uint32_t iz = gl_GlobalInvocationID.z; iz < N_OD_OH; iz += gl_NumWorkGroups.z) {
const uint32_t in_ = iz / OD_OH;
const uint32_t iod = (iz - in_*OD_OH) / OH;
const uint32_t ioh = iz % OH;
const uint32_t iiw = iow * s0 + ikw * d0 - p0;
const uint32_t iih = ioh * s1 + ikh * d1 - p1;
const uint32_t iid = iod * s2 + ikd * d2 - p2;
const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw;
const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10;
#if BDA
D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst);
if (iih >= IH || iiw >= IW || iid >= ID) {
dst_addr.d = D_TYPE(0.0f);
} else {
dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]);
}
#else
if (iih >= IH || iiw >= IW || iid >= ID) {
data_d[offset_dst + get_doffset()] = D_TYPE(0.0f);
} else {
data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]);
}
#endif
}
}
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
#define BLOCK_SIZE 512
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
shared FLOAT_TYPE sum[BLOCK_SIZE];
void main() {
const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x;
const uint tid = gl_LocalInvocationID.x;
sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]);
sum[tid] += xi * xi;
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
sum[tid] += sum[tid + s];
}
barrier();
}
const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1)));
[[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) {
data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col]));
}
}
#version 450
#include "generic_head.glsl"
#include "types.glsl"
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer X {A_TYPE data_a[];};
layout (binding = 1) writeonly buffer D {D_TYPE data_d[];};
void main() {
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x;
if (i >= p.KX) {
return;
}
const float val = float(data_a[i]);
data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1);
}
#version 450
#include "types.glsl"
#include "generic_binary_head.glsl"
const uint num_threads = 256;
layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in;
void main() {
uint idx = get_idx();
// num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation
const uint num_iter = 2;
[[unroll]] for (uint i = 0; i < num_iter; ++i) {
if (idx >= p.ne) {
continue;
}
uint i00, i01, i02, i03;
get_indices(idx, i00, i01, i02, i03);
data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]));
idx += num_threads;
}
}
#version 450
#extension GL_EXT_control_flow_attributes : enable
layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {float data_a[];};
layout (binding = 0) readonly buffer A4 {vec4 data_a4[];};
layout (binding = 1) writeonly buffer D {float data_d[];};
layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];};
layout (push_constant) uniform parameter {
uint ne;
uint k_num;
} p;
void main() {
// Each invocation handles four consecutive components
const uint idx = gl_GlobalInvocationID.x * 4;
if (idx >= p.ne) {
return;
}
// Check if all four components are in bounds and aligned,
// then use vector loads
if (idx + 3 < p.ne && (p.ne % 4) == 0) {
vec4 result = vec4(0.0f);
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
result += data_a4[(i * p.ne + idx) / 4];
}
data_d4[idx / 4] = result;
} else {
[[unroll]] for (uint j = 0; j < 4; ++j) {
if (idx + j < p.ne) {
float result = 0.0f;
[[unroll]] for (uint i = 0; i < p.k_num; i++) {
result += data_a[i * p.ne + idx + j];
}
data_d[idx + j] = result;
}
}
}
}
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16)
#define K_PER_ITER 8
#else
#define K_PER_ITER 2
#endif
uint a_offset, b_offset, d_offset, y_offset;
void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter)
{
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const uint col = i*BLOCK_SIZE + K_PER_ITER*tid;
const uint iqs = (col%QUANT_K)/QUANT_R; // quant index
const uint iybs = col - col%QUANT_K; // y block start index
#if K_PER_ITER == 8
#if QUANT_R == 2
const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]);
const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y);
const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w);
#else
const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]);
const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]);
#endif
#else
// Check if the second of the pair of elements is OOB, and don't fetch B or
// accumulate it. We still fetch a pair of elements for A, which is fine for
// quantized formats since they'll be within the same block. We should
// probably skip fetching the second element for F16/F32, but as of now we
// still do.
const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols);
FLOAT_TYPE b0 = 0, b1 = 0;
b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]);
if (!OOB) {
b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]);
}
#endif
uint ibi = first_row*p.ncols;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib = (ibi + col)/QUANT_K; // block index
ibi += p.ncols;
#if K_PER_ITER == 8
vec4 v = dequantize4(ib, iqs, a_offset);
vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset);
const vec2 dm = get_dm(ib, a_offset);
if (dm.y != 0) { // quant has min component
v = v * dm.x + dm.y;
v2 = v2 * dm.x + dm.y;
}
// matrix multiplication
FLOAT_TYPE rowtmp = dot(bv0, v);
rowtmp += dot(bv1, v2);
if (dm.y == 0)
rowtmp *= dm.x;
temp[j][n] += rowtmp;
#else
const vec2 v = dequantize(ib, iqs, a_offset);
// matrix multiplication
temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]);
if (!OOB) {
temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]);
}
#endif
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
const uint tid = gl_LocalInvocationID.x;
get_offsets(a_offset, b_offset, d_offset);
a_offset /= QUANT_K;
y_offset = QUANT_R == 1 ? 1 : QUANT_K/2;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE);
if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) {
num_iters++;
}
int unroll_count = 4;
uint unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
// If the K dimension is odd, we need lastiter==true on the last iteration
// so OOB is computed correctly. Skip some unrolling to make that happen.
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif
uint i = 0;
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
i++;
}
}
unroll_count = 2;
unrolled_iters = num_iters & ~(unroll_count - 1);
#if K_PER_ITER == 2
if ((p.ncols & 1) != 0 &&
unrolled_iters == num_iters &&
unrolled_iters > 0) {
unrolled_iters -= unroll_count;
}
#endif
while (i < unrolled_iters) {
// Manually partially unroll the loop
[[unroll]] for (uint k = 0; k < unroll_count; ++k) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false);
i++;
}
}
while (i < num_iters) {
iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true);
i++;
}
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
#ifdef NEEDS_INIT_IQ_SHMEM
init_iq_shmem(gl_WorkGroupSize);
#endif
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#extension GL_EXT_shader_8bit_storage : require
#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM
#extension GL_KHR_shader_subgroup_basic : require
#extension GL_KHR_shader_subgroup_arithmetic : require
#endif
#ifdef MUL_MAT_ID
#define EXPERT_COUNT 8
#endif
#include "types.glsl"
#ifndef MMQ
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
#else
layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];};
#endif
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
#ifdef B_TYPE_VEC2
layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];};
#endif
#ifdef B_TYPE_VEC4
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
#endif
layout (binding = 2) writeonly buffer D {D_TYPE data_d[];};
#ifdef MUL_MAT_ID
layout (binding = 3) readonly buffer IDS {int data_ids[];};
#endif
#include "dequant_funcs.glsl"
layout (push_constant) uniform parameter
{
uint ncols;
uint stride_a;
uint stride_b;
uint stride_d;
uint batch_stride_a;
uint batch_stride_b;
uint batch_stride_d;
#ifdef MUL_MAT_ID
uint nei0;
uint ne11;
#else
uint ne02;
uint ne12;
uint broadcast2;
uint broadcast3;
#endif
} p;
void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) {
#ifdef MUL_MAT_ID
const uint expert_idx = gl_GlobalInvocationID.y;
#else
const uint batch_idx = gl_GlobalInvocationID.y;
#endif
#ifndef MUL_MAT_ID
uint batch_idx_a = 0;
if (batch_idx != 0) {
const uint i13 = batch_idx / p.ne12;
const uint i12 = batch_idx % p.ne12;
const uint i03 = i13 / p.broadcast3;
const uint i02 = i12 / p.broadcast2;
batch_idx_a = i03 * p.ne02 + i02;
}
#else
const uint expert_id = data_ids[expert_idx];
#endif
a_offset =
#ifdef MUL_MAT_ID
expert_id * p.batch_stride_a;
#else
batch_idx_a * p.batch_stride_a;
#endif
b_offset =
#ifdef MUL_MAT_ID
(expert_idx % p.ne11) * p.stride_b;
#else
batch_idx * p.batch_stride_b;
#endif
d_offset =
#ifdef MUL_MAT_ID
expert_idx * p.stride_d;
#else
batch_idx * p.batch_stride_d;
#endif
}
layout (constant_id = 0) const uint BLOCK_SIZE = 32;
layout (constant_id = 1) const uint NUM_ROWS = 1;
layout (constant_id = 2) const uint NUM_COLS = 1;
#ifdef USE_SUBGROUP_ADD_NO_SHMEM
void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
temp[j][n] = subgroupAdd(temp[j][n]);
}
}
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
}
}
}
}
#else
shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE];
void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) {
// subgroupAdd is probably faster on devices that support it,
// particularly when the workgroup has more than one subgroup
#if USE_SUBGROUP_ADD
// sum up partial sums within a subgroup
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
temp[j][n] = subgroupAdd(temp[j][n]);
}
}
// Go through shared memory to sum partials across subgroups
if (gl_SubgroupInvocationID == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[j][n][gl_SubgroupID] = temp[j][n];
}
}
}
barrier();
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
temp[j][n] = FLOAT_TYPE(0);
[[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) {
temp[j][n] += tmpsh[j][n][s];
}
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]);
}
}
}
#else
// sum up partial sums and write back result
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[j][n][tid] = temp[j][n];
}
}
barrier();
[[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) {
if (tid < s) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
tmpsh[j][n][tid] += tmpsh[j][n][tid + s];
}
}
}
barrier();
}
if (tid == 0) {
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]);
}
}
}
#endif
}
#endif
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 32 * ib32;
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint16_t[4] scales = data_a[ibi].scales;
const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12;
const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x);
const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1));
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1));
const uint qs = data_a[ibi].qs[4 * ib32 + l];
const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA;
const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1);
const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int k = 0; k < 4; ++k) {
sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta,
fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum));
}
temp[j][n] = fma(dl, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 8 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/8;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 8; // 0...7
const uint ix = tid / 8;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 32 * ib32;
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint qh = data_a[ibi].qh[ib32];
const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1);
const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA;
[[unroll]] for (uint l = 0; l < 4; ++l) {
const uint qs = data_a[ibi].qs[4 * ib32 + l];
const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3);
const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int k = 0; k < 4; ++k) {
sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta,
fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum));
}
temp[j][n] = fma(dl, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 8 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/8;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 8; // 0...7
const uint ix = tid / 8;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 16 * itid;
const uint nibble_shift = 4 * (itid & 1);
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;
const float db = d * (0.5 + scale) * 0.25;
const uint qh = data_a[ibi].qh[ib32];
const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147
const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy;
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint8_t sign = sign16[l];
const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300);
const uvec2 grid = iq2s_grid[qs];
const vec4 grid0 = vec4(unpack8(grid.x));
const vec4 grid1 = vec4(unpack8(grid.y));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
temp[j][n] = fma(db, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 16; // 0...15
const uint ix = tid / 16;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 16 * itid;
const uint nibble_shift = 4 * (itid & 1);
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF;
const float db = d * (0.5 + scale) * 0.25;
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint qs = data_a[ibi].qs[2 * itid + l];
const uint sign = qs >> 9;
const uint sign7 = bitCount(sign);
const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x));
const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
temp[j][n] = fma(db, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 16; // 0...15
const uint ix = tid / 16;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 16 * itid;
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint signscale = pack32(u16vec2(
data_a_packed16[ibi].qs[4 * ib32 + 2],
data_a_packed16[ibi].qs[4 * ib32 + 3]));
const float db = d * 0.25 * (0.5 + (signscale >> 28));
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint qs = data_a[ibi].qs[8 * ib32 + 2 * (itid & 1) + l];
const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7);
const uint sign7 = bitCount(sign);
const vec4 grid0 = vec4(unpack8(iq2xxs_grid[qs].x));
const vec4 grid1 = vec4(unpack8(iq2xxs_grid[qs].y));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
temp[j][n] = fma(db, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 16; // 0...15
const uint ix = tid / 16;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 32 * ib32;
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF;
const float dscale = d * (1 + 2 * scale);
const uint qh = data_a[ibi].qh[ib32];
FLOAT_TYPE sum[NUM_COLS];
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
sum[j] = 0.0;
}
[[unroll]] for (uint l = 0; l < 4; ++l) {
const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147
const uint sign = data_a[ibi].signs[4 * ib32 + l];
const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)]));
const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)]));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
sum[j] =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w),
sum[j]))))))));
}
}
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
temp[j][n] = fma(dscale, sum[j], temp[j][n]);
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 8 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/8;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 8; // 0...7
const uint ix = tid / 8;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) {
const uint y_idx = i * QUANT_K + 16 * itid;
const uint ib32 = itid / 2; // 0..7
uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const float d = float(data_a[ibi].d);
const uint signscale = pack32(u16vec2(
data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32],
data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32 + 1]));
const float db = d * 0.5 * (0.5 + (signscale >> 28));
[[unroll]] for (uint l = 0; l < 2; ++l) {
const uint qs0 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l];
const uint qs1 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l + 1];
const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7);
const uint sign7 = bitCount(sign);
const vec4 grid0 = vec4(unpack8(iq3xxs_grid[qs0]));
const vec4 grid1 = vec4(unpack8(iq3xxs_grid[qs1]));
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]);
const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]);
FLOAT_TYPE sum =
fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x),
fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y),
fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z),
fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w),
fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x),
fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y),
fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z),
fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w),
FLOAT_TYPE(0.0)))))))));
temp[j][n] = fma(db, sum, temp[j][n]);
}
}
ibi += num_blocks_per_row;
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint blocks_per_wg = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid % 16; // 0...15
const uint ix = tid / 16;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
[[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg)
calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
init_iq_shmem(gl_WorkGroupSize);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#define BLOCK_SIZE 32
#define FLOAT_TYPE float
layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
layout (push_constant) uniform parameter
{
uint ncols_x;
uint nrows_x;
uint row_stride_x;
uint channel_stride_x;
uint channel_stride_y;
uint channel_x_divisor;
uint ne12;
uint b_offset;
uint d_offset;
uint nb03;
uint nb13;
uint nb23;
} p;
shared FLOAT_TYPE tmp[BLOCK_SIZE];
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint row_x = gl_GlobalInvocationID.y;
const uint channel = gl_GlobalInvocationID.z;
const uint i3 = gl_WorkGroupID.x;
const uint channel_x = channel / p.channel_x_divisor;
const uint channel_y = channel % p.ne12;
const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x;
const uint row_dst = row_x;
const uint idst = i3*p.nb23 + channel*nrows_dst + row_dst;
FLOAT_TYPE temp = 0.0f;
// Detect alignment for vector loads
bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0;
for (uint col_x0 = 0; col_x0 < p.ncols_x;) {
// Unroll 2x and do vec4 loads if aligned
const uint unroll_count = 2;
if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) {
[[unroll]] for (uint i = 0; i < unroll_count; ++i) {
const uint col_x = col_x0 + 4*tid;
const uint row_y = col_x;
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
temp += dot(av4, bv4);
col_x0 += 4*BLOCK_SIZE;
}
// do vec4 loads if aligned
} else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
const uint col_x = col_x0 + 4*tid;
const uint row_y = col_x;
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
const vec4 bv4 = vec4(data_b_v4[iy / 4]);
temp += dot(av4, bv4);
col_x0 += 4*BLOCK_SIZE;
} else {
const uint col_x = col_x0 + tid;
if (col_x >= p.ncols_x) {
break;
}
const uint row_y = col_x;
const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x;
const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp);
col_x0 += BLOCK_SIZE;
}
}
tmp[tid] = temp;
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
tmp[tid] += tmp[tid + s];
}
barrier();
}
if (tid == 0) {
dst[idst] = tmp[0];
}
}
#version 450
#extension GL_EXT_control_flow_attributes : enable
#extension GL_EXT_shader_16bit_storage : require
#if USE_SUBGROUP_ADD
#extension GL_KHR_shader_subgroup_arithmetic : enable
#endif
#define FLOAT_TYPE float
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
layout (binding = 0) readonly buffer A {A_TYPE data_a[];};
layout (binding = 1) readonly buffer B {B_TYPE data_b[];};
layout (binding = 2) writeonly buffer D {D_TYPE dst[];};
layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];};
layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];};
layout(constant_id = 0) const int BLOCK_SIZE = 32;
// gqa_ratio is in the range [1,8]
layout(constant_id = 1) const uint gqa_ratio = 1;
layout (push_constant) uniform parameter
{
uint ncols_x;
uint nrows_x;
uint nchannels_x;
uint nchannels_y;
uint b_offset;
uint d_offset;
} p;
#if !USE_SUBGROUP_ADD
shared FLOAT_TYPE tmp[8][BLOCK_SIZE];
#endif
void main() {
const uint tid = gl_LocalInvocationID.x;
const uint row_x = gl_GlobalInvocationID.y;
uint channel, channel_x;
// When gqa_ratio > 1, each invocation does multiple rows.
// The row in the A matrix is starting from channel / gqa_ratio and the
// rows in the B matrix are [channel, channel+gqa_ratio).
// When gpa_ratio is 1, each invocation does one row.
if (gqa_ratio > 1) {
channel_x = gl_GlobalInvocationID.z;
channel = channel_x * gqa_ratio;
} else {
channel = gl_GlobalInvocationID.z;
channel_x = channel / (p.nchannels_y / p.nchannels_x);;
}
const uint nrows_y = p.ncols_x;
const uint nrows_dst = p.nrows_x;
const uint row_dst = row_x;
FLOAT_TYPE temp[8];
[[unroll]] for (uint i = 0; i < 8; ++i) {
temp[i] = FLOAT_TYPE(0.0f);
}
// Detect alignment for vector loads
bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0;
for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) {
// Use vec4 loads if aligned
if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) {
uint col_x = col_x0 + 4*tid;
const uint row_y = col_x;
// x is transposed and permuted
const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
const vec4 av4 = vec4(data_a_v4[ix / 4]);
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
// y is not transposed but permuted
const uint iy = (channel + c)*nrows_y + row_y;
vec4 bv4 = data_b_v4[iy / 4];
temp[c] += dot(av4, bv4);
}
col_x0 += 3*BLOCK_SIZE;
} else {
const uint col_x = col_x0 + tid;
if (col_x >= p.ncols_x) {
break;
}
// x is transposed and permuted
const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x;
const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]);
const uint row_y = col_x;
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
// y is not transposed but permuted
const uint iy = (channel + c)*nrows_y + row_y;
temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]);
}
}
}
#if USE_SUBGROUP_ADD
// reduce vec4 at a time
vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]);
t = subgroupAdd(t);
temp[0] = t[0];
temp[1] = t[1];
temp[2] = t[2];
temp[3] = t[3];
if (gqa_ratio > 4) {
t = vec4(temp[4], temp[5], temp[6], temp[7]);
t = subgroupAdd(t);
temp[4] = t[0];
temp[5] = t[1];
temp[6] = t[2];
temp[7] = t[3];
}
#else
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
tmp[c][tid] = temp[c];
}
// sum up partial sums and write back result
barrier();
[[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) {
if (tid < s) {
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
temp[c] += tmp[c][tid + s];
tmp[c][tid] = temp[c];
}
}
barrier();
}
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
temp[c] = tmp[c][tid];
}
#endif
if (tid == 0) {
[[unroll]] for (uint c = 0; c < gqa_ratio; ++c) {
// dst is not transposed and not permuted
const uint idst = (channel + c)*nrows_dst + row_dst;
dst[idst] = temp[c];
}
}
}
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16];
shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
uint csel = 0;
void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
csel ^= 1;
if (!all_threads) { // when we don't have enough blocks to use all threads
if (i < num_blocks_per_row) {
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
}
barrier();
if (i >= num_blocks_per_row)
continue;
} else {
const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]);
sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF);
sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF);
barrier();
}
const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16);
const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
vec2 d = vec2(data_a[ib0 + i].d);
const FLOAT_TYPE dall = FLOAT_TYPE(d.x);
const FLOAT_TYPE dmin = FLOAT_TYPE(d.y);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
FLOAT_TYPE sum1 = FLOAT_TYPE(0.0);
FLOAT_TYPE sum2 = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[csel][ix][ 8*v_im] * qs_u32_0[l ],
fma(FLOAT_TYPE(b16[l]), sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2],
fma(FLOAT_TYPE(b32[l]), sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l ],
fma(FLOAT_TYPE(b48[l]), sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2],
fma(FLOAT_TYPE(b64[l]), sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l ],
fma(FLOAT_TYPE(b80[l]), sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2],
fma(FLOAT_TYPE(b96[l]), sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l ],
fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1))))))));
sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[csel][ix][ 8*v_im],
fma(FLOAT_TYPE(b16[l]), sccache2[csel][ix][1 + 8*v_im],
fma(FLOAT_TYPE(b32[l]), sccache2[csel][ix][2 + 8*v_im],
fma(FLOAT_TYPE(b48[l]), sccache2[csel][ix][3 + 8*v_im],
fma(FLOAT_TYPE(b64[l]), sccache2[csel][ix][4 + 8*v_im],
fma(FLOAT_TYPE(b80[l]), sccache2[csel][ix][5 + 8*v_im],
fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im],
fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2))))))));
}
temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n]));
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...15
const uint ix = tid/16;
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_in = itid - 8*v_im; // 0...7
const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0;
const uint y_offset = 128*v_im + l0;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
const uint nbr_par_th = num_blocks_per_row%it_size;
const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
uint i0 = 0;
[[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
#version 450
#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require
#include "mul_mat_vec_base.glsl"
layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in;
shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8];
FLOAT_TYPE temp[NUM_COLS][NUM_ROWS];
uint csel = 0;
void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) {
const uint y_idx = i * QUANT_K + y_offset;
[[unroll]] for (uint n = 0; n < num_rows; ++n) {
const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row;
csel ^= 1;
if (!all_threads) { // when we don't have enough blocks to use all threads
if (i < num_blocks_per_row)
sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
barrier();
if (i >= num_blocks_per_row)
continue;
}
const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16));
const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2));
const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2));
const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2));
const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2));
// 0, 1, 16, 17
uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8);
qs_u32 |= (uint32_t(data_a[ib0 + i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16;
const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303));
const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303));
const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303));
const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303));
if (all_threads) {
sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32);
barrier();
}
const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d);
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]);
vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]);
vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]);
vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]);
vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]);
vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]);
vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]);
vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]);
FLOAT_TYPE sum = FLOAT_TYPE(0.0);
[[unroll]] for (int l = 0; l < 2; ++l) {
sum = fma(FLOAT_TYPE( b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l ] - hmk_0[l ],
fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2],
fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l ] - hmk_1[l ],
fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2],
fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l ] - hmk_2[l ],
fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2],
fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l ] - hmk_3[l ],
fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum))))))));
}
temp[j][n] = fma(d, sum, temp[j][n]);
}
}
}
void compute_outputs(const uint32_t first_row, const uint32_t num_rows) {
uint a_offset, b_offset, d_offset;
get_offsets(a_offset, b_offset, d_offset);
const uint num_blocks_per_row = p.ncols / QUANT_K;
// 16 threads are used to process each block
const uint it_size = gl_WorkGroupSize.x/16;
const uint tid = gl_LocalInvocationID.x;
const uint itid = tid%16; // 0...15
const uint ix = tid/16;
const uint itid8 = itid%8;
const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128...
const uint v_im4 = v_im*4;
const uint v_in = itid - 8*v_im; // 0...7
const uint32_t m = 0x01010101 << (4 * v_im);
uint32_t hm_m[4];
[[unroll]] for (uint j = 0; j < 4; ++j)
hm_m[j] = m << j;
const uint l0 = 2*v_in; // 0...15
const uint q_offset = 32*v_im + l0;
const uint y_offset = 128*v_im + l0;
[[unroll]] for (uint j = 0; j < NUM_COLS; ++j) {
[[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) {
temp[j][i] = FLOAT_TYPE(0);
}
}
const uint s_shift = v_im4 + 2*(itid8/4);
const uint nbr_par_th = num_blocks_per_row%it_size;
const uint nbr_all_th = num_blocks_per_row - nbr_par_th;
uint i0 = 0;
[[unroll]] for (; i0 < nbr_all_th; i0 += it_size)
calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true);
calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false);
reduce_result(temp, d_offset, first_row, num_rows, tid);
}
void main() {
const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z);
// do NUM_ROWS at a time, unless there aren't enough remaining rows
if (first_row + NUM_ROWS <= p.stride_d) {
compute_outputs(first_row, NUM_ROWS);
} else {
if (first_row >= p.stride_d) {
return;
}
compute_outputs(first_row, p.stride_d - first_row);
}
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment