OpenDAS / ollama · Commits

Commit 1dc4b857
Authored Mar 06, 2025 by xuxzh1

opt rms_norm

Parent: 23a7a73f
Changes: 1
Showing 1 changed file with 76 additions and 1 deletion

llama/ggml-cuda/norm.cu  (+76 −1)
...
@@ -157,6 +157,81 @@ static __global__ void __launch_bounds__(1024) rms_norm_f32(const float * x, flo
     }
 }
 
+using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float;
+
+template <typename T, int VEC, int NUM_WARPS>
+__inline__ __device__ T BlockReduceSumVEC(T & val, T * shared) {
+#pragma unroll
+    for (int offset = 32; offset > 0; offset >>= 1) {
+        val += __shfl_xor_sync(0xffffffff, val, offset, 64); // shuffle across all 64 lanes
+    }
+    if constexpr (1 < NUM_WARPS) {
+        const int tid = threadIdx.x;
+        const int lid = tid % 64;
+        const int wid = tid / 64;
+        if (lid == 0) {
+            shared[wid] = val;
+        }
+        __syncthreads();
+        if (wid == 0 && lid < NUM_WARPS) {
+#pragma unroll
+            for (int offset = NUM_WARPS / 2; offset > 0; offset >>= 1) {
+                shared[lid] += __shfl_xor_sync(0xffffffff, shared[lid], offset, 64); // shuffle across all 64 lanes
+            }
+            val = shared[lid];
+        }
+    }
+    return val;
+}
+
+template <int block_size, int VEC = 4>
+static __global__ void __launch_bounds__(1024) rms_norm_f32_opt1(const float * x, float * dst, const int ncols, const float eps) {
+    const int row = blockIdx.x * blockDim.y + threadIdx.y;
+    const int tid = threadIdx.x;
+    constexpr int NUM_WARPS = block_size / 64;
+
+    __shared__ float lds_sum[NUM_WARPS * 4];
+    __shared__ float sum_val;
+
+    float tmp = 0.0f; // partial sum for thread in warp
+    floatx4 xi_vec;
+
+    for (int col = tid * VEC; col < ncols; col += block_size * VEC) {
+        xi_vec = *(floatx4 *)(x + row * ncols + col);
+#pragma unroll
+        for (int i = 0; i < VEC; ++i) {
+            tmp += xi_vec[i] * xi_vec[i];
+        }
+    }
+
+    tmp = BlockReduceSumVEC<float, VEC, NUM_WARPS>(tmp, lds_sum);
+    // tmp = __shfl_sync(0xffffffff, tmp, 0); // lds or shfl
+
+    if (tid == 0) sum_val = rsqrtf(tmp / ncols + eps);
+    __syncthreads();
+
+    float scale = sum_val; // load the scale into a register once and reuse it across the write-back loop
+    for (int col = tid * VEC; col < ncols; col += block_size * VEC) {
+        xi_vec = *(floatx4 *)(x + row * ncols + col);
+#pragma unroll
+        for (int i = 0; i < VEC; ++i) {
+            xi_vec[i] = xi_vec[i] * scale;
+        }
+        *(floatx4 *)(dst + row * ncols + col) = xi_vec;
+    }
+}
+
 static void norm_f32_cuda(const float * x, float * dst, const int ncols, const int nrows, const float eps, cudaStream_t stream) {
     GGML_ASSERT(ncols % WARP_SIZE == 0);
     if (ncols < 1024) {
...
@@ -185,7 +260,7 @@ static void rms_norm_f32_cuda(const float * x, float * dst, const int ncols, con
         rms_norm_f32<WARP_SIZE><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     } else {
         const dim3 block_dims(1024, 1, 1);
-        rms_norm_f32<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
+        rms_norm_f32_opt1<1024><<<nrows, block_dims, 0, stream>>>(x, dst, ncols, eps);
     }
 }
...
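Reviewer note: below is a minimal CPU reference sketch for the math implemented by rms_norm_f32_opt1, i.e. dst[i] = x[i] * rsqrt(sum(x[i]^2) / ncols + eps). It is not part of this commit; the names (rms_norm_f32_ref) and the test size are made up for illustration, and it assumes ncols is a multiple of VEC (4), which the vectorized floatx4 loads in the kernel also require.

// Sketch (not part of this commit): CPU reference for one row of rms_norm_f32_opt1.
#include <cmath>
#include <cstdio>
#include <vector>

static void rms_norm_f32_ref(const float * x, float * dst, const int ncols, const float eps) {
    float sum = 0.0f;
    for (int i = 0; i < ncols; ++i) {
        sum += x[i] * x[i];   // accumulation order differs from the GPU reduction,
    }                         // so compare against the kernel with a tolerance, not bit-exactly
    const float scale = 1.0f / std::sqrt(sum / ncols + eps);
    for (int i = 0; i < ncols; ++i) {
        dst[i] = x[i] * scale;
    }
}

int main() {
    const int ncols = 4096;   // hypothetical size; must be a multiple of VEC (4) for the floatx4 path
    std::vector<float> x(ncols), ref(ncols);
    for (int i = 0; i < ncols; ++i) {
        x[i] = 0.01f * (i % 97) - 0.5f;
    }
    rms_norm_f32_ref(x.data(), ref.data(), ncols, 1e-6f);
    // To validate the kernel, copy its dst back (e.g. with cudaMemcpy) and check
    // |dst[i] - ref[i]| against a small tolerance for each row.
    std::printf("ref[0] = %f, ref[%d] = %f\n", ref[0], ncols - 1, ref[ncols - 1]);
    return 0;
}

The hard-coded shuffle width of 64 and block_size / 64 in the new kernel suggest it targets 64-wide waves; the reference above is independent of that choice.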