chenpangpang / base-image-test · Commits

Commit e0c65b0d authored Jan 13, 2026 by chenpangpang
Add new file
Parent: b2d7fb87

Changes: 1 changed file with 259 additions and 0 deletions

jax/test.py (new file, 0 → 100644, +259 −0)

import jax
import jax.numpy as jnp
from jax import random, grad, jit, vmap
import time
# Check whether JAX is using the GPU
print("JAX version:", jax.__version__)
print("Available devices:", jax.devices())
print("Device in use:", jax.devices()[0])
print("Device kind:", jax.devices()[0].device_kind)
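
# --- Illustrative addition (not part of the original file): a minimal sketch of
# explicit device placement, assuming at least one device is visible.
# jax.default_backend() and jax.device_put() are public JAX APIs; `example_array`
# is a name introduced only for this sketch.
print("Default backend:", jax.default_backend())
example_array = jax.device_put(jnp.arange(4.0), jax.devices()[0])
print("Explicitly placed array:", example_array)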
# 1. Basic GPU matrix operations
def basic_gpu_operations():
    print("\n=== Basic GPU matrix operations ===")

    # Create large matrices (these are allocated on the GPU automatically)
    key = random.PRNGKey(42)
    size = 5000  # large matrices

    # Generate random matrices (note: the same key is reused, so A and B are identical)
    A = random.normal(key, (size, size))
    B = random.normal(key, (size, size))

    # GPU matrix multiplication
    start = time.time()
    C = jnp.dot(A, B)  # runs on the GPU automatically
    C.block_until_ready()  # JAX dispatches asynchronously; wait so the timing is meaningful
    compute_time = time.time() - start

    print(f"Matrix size: {size} x {size}")
    print(f"GPU matmul time: {compute_time:.4f} s")

    # Inspect the result
    print(f"Shape of result matrix C: {C.shape}")
    print(f"Mean of C: {jnp.mean(C):.6f}")
    print(f"Std of C: {jnp.std(C):.6f}")

    return C
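
# --- Illustrative addition (not part of the original file): a small timing
# helper. Because JAX dispatches work asynchronously, a fair benchmark has to
# call block_until_ready() on the result before reading the clock;
# `timed_matmul` is a hypothetical helper name used only in this sketch.
def timed_matmul(a, b):
    start = time.time()
    out = jnp.dot(a, b)
    out.block_until_ready()  # wait until the device computation has finished
    return out, time.time() - start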
# 2. Speeding things up with JIT compilation
def jit_acceleration():
    print("\n=== JIT compilation example ===")

    # Define a moderately complex function
    def complex_function(x):
        return jnp.sum(jnp.sin(x) * jnp.exp(-x ** 2) + jnp.log1p(jnp.abs(x)))

    # Create input data
    key = random.PRNGKey(0)
    x = random.normal(key, (10000, 1000))

    # Without JIT
    start = time.time()
    result_no_jit = complex_function(x)
    result_no_jit.block_until_ready()  # wait for the async computation before stopping the clock
    time_no_jit = time.time() - start

    # With JIT compilation
    jitted_function = jit(complex_function)

    # The first call pays the compilation cost
    start = time.time()
    result_jit = jitted_function(x)
    result_jit.block_until_ready()
    time_first_call = time.time() - start

    # The second call (already compiled)
    start = time.time()
    result_jit = jitted_function(x)
    result_jit.block_until_ready()
    time_second_call = time.time() - start

    print(f"Without JIT: {time_no_jit:.4f} s")
    print(f"First JIT call (incl. compilation): {time_first_call:.4f} s")
    print(f"Subsequent JIT calls: {time_second_call:.4f} s")
    print(f"Speedup: {time_no_jit / time_second_call:.2f}x")

    return result_jit
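
# --- Illustrative addition (not part of the original file): inspecting what jit
# traces. jax.make_jaxpr() is a public API that returns the intermediate
# representation JAX builds before XLA compilation, which is why the first
# jitted call above is slower than later ones. `_inspect_fn` is a name used
# only in this sketch.
def _inspect_fn(x):
    return jnp.sum(jnp.sin(x) * jnp.exp(-x ** 2))

print(jax.make_jaxpr(_inspect_fn)(jnp.arange(3.0)))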
# 3. Automatic differentiation on the GPU
def autodiff_gpu():
    print("\n=== Automatic differentiation on the GPU ===")

    # Define a loss function
    def loss(params, x, y):
        w, b = params
        predictions = jnp.dot(x, w) + b
        return jnp.mean((predictions - y) ** 2)

    # Generate synthetic data
    key = random.PRNGKey(1)
    key1, key2, key3 = random.split(key, 3)

    n_samples = 10000
    n_features = 100

    X = random.normal(key1, (n_samples, n_features))
    true_w = random.normal(key2, (n_features, 1))
    true_b = random.normal(key3, (1,))
    Y = jnp.dot(X, true_w) + true_b + 0.1 * random.normal(key1, (n_samples, 1))

    # Initialize parameters
    params = (random.normal(key1, (n_features, 1)), jnp.zeros((1,)))

    # Gradient function, JIT-compiled for the GPU
    grad_loss = jit(grad(loss, argnums=0))

    # Compute gradients on the GPU
    start = time.time()
    gradients = grad_loss(params, X, Y)
    jax.block_until_ready(gradients)  # wait for the async computation before stopping the clock
    grad_time = time.time() - start

    print(f"Data shapes: X={X.shape}, Y={Y.shape}")
    print(f"GPU gradient computation time: {grad_time:.4f} s")
    print(f"Weight gradient shape: {gradients[0].shape}")
    print(f"Bias gradient shape: {gradients[1].shape}")

    return gradients
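
# --- Illustrative addition (not part of the original file): one gradient-descent
# update applied to the same (w, b) parameter structure, assuming gradients with
# the same structure as the parameters. jax.tree_util.tree_map() is a public
# API; `sgd_step` and `learning_rate` are names used only in this sketch.
def sgd_step(params, grads, learning_rate=0.01):
    return jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)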
# 4. Batch processing example
def batch_processing():
    print("\n=== Batch processing on the GPU ===")

    # Define a function that processes a single sample
    def process_sample(x):
        # Some arbitrary transformation
        return jnp.sum(jnp.tanh(x) * jnp.cos(x))

    # Vectorize it with vmap
    batched_process = vmap(process_sample)

    # Create batched data
    key = random.PRNGKey(3)
    batch_size = 10000
    sample_size = 1000
    batch_data = random.normal(key, (batch_size, sample_size))

    # Process one sample at a time (slow)
    def sequential_process(data):
        results = []
        for i in range(data.shape[0]):
            results.append(process_sample(data[i]))
        return jnp.array(results)

    # Time the sequential version
    start = time.time()
    sequential_results = sequential_process(batch_data)
    sequential_results.block_until_ready()
    sequential_time = time.time() - start

    # Time the vmap version (runs in parallel on the GPU)
    start = time.time()
    batched_results = batched_process(batch_data)
    batched_results.block_until_ready()
    batched_time = time.time() - start

    print(f"Batch size: {batch_size}, sample size: {sample_size}")
    print(f"Sequential processing time: {sequential_time:.4f} s")
    print(f"Batched processing time: {batched_time:.4f} s")
    print(f"Speedup: {sequential_time / batched_time:.2f}x")

    # Check that the two approaches agree
    error = jnp.max(jnp.abs(sequential_results - batched_results))
    print(f"Maximum difference between results: {error:.10f}")

    return batched_results
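
# --- Illustrative addition (not part of the original file): vmap over a function
# of two arguments, mapping only the first. in_axes is part of the public vmap
# signature; `scale_rows` is a name used only in this sketch.
def scale_rows(x, scale):
    return jnp.tanh(x) * scale

# Maps axis 0 of the first argument and broadcasts the second: result shape (4, 3).
scaled_demo = vmap(scale_rows, in_axes=(0, None))(jnp.ones((4, 3)), 2.0)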
# 5. Advanced example: a simple neural network
def neural_network_example():
    print("\n=== Neural network layers on the GPU ===")

    # Define a simple dense layer (defined here but not used below)
    def dense_layer(params, x):
        w, b = params
        return jnp.dot(x, w) + b

    # ReLU activation
    def relu(x):
        return jnp.maximum(0, x)

    # Two-layer network
    def two_layer_net(params, x):
        w1, b1, w2, b2 = params
        hidden = relu(jnp.dot(x, w1) + b1)
        output = jnp.dot(hidden, w2) + b2
        return output

    # Initialize parameters
    key = random.PRNGKey(4)
    key1, key2, key3, key4 = random.split(key, 4)

    input_size = 784   # MNIST image size
    hidden_size = 128
    output_size = 10   # 10 classes

    # Initialize weights
    w1 = random.normal(key1, (input_size, hidden_size)) * jnp.sqrt(2.0 / input_size)
    b1 = jnp.zeros(hidden_size)
    w2 = random.normal(key2, (hidden_size, output_size)) * jnp.sqrt(2.0 / hidden_size)
    b2 = jnp.zeros(output_size)
    params = (w1, b1, w2, b2)

    # Create a batch of inputs
    batch_size = 1024
    x_batch = random.normal(key3, (batch_size, input_size))

    # JIT-compile the network
    jitted_net = jit(two_layer_net)

    # Forward pass
    start = time.time()
    outputs = jitted_net(params, x_batch)
    outputs.block_until_ready()  # wait for the async computation before stopping the clock
    inference_time = time.time() - start

    print(f"Network architecture: {input_size} -> {hidden_size} -> {output_size}")
    print(f"Batch size: {batch_size}")
    print(f"GPU inference time: {inference_time:.6f} s")
    print(f"Output shape: {outputs.shape}")
    print(f"Mean output over the batch: {jnp.mean(outputs, axis=0)}")

    return outputs
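
# --- Illustrative addition (not part of the original file): a jitted training
# step for a network shaped like the one above, assuming `net(params, x)` returns
# predictions and `params` is a pytree. jax.value_and_grad() and
# jax.tree_util.tree_map() are public APIs; `make_train_step` is a name used only
# in this sketch.
def make_train_step(net, learning_rate=0.01):
    def mse(params, x, y):
        return jnp.mean((net(params, x) - y) ** 2)

    @jit
    def train_step(params, x, y):
        loss_value, grads = jax.value_and_grad(mse)(params, x, y)
        new_params = jax.tree_util.tree_map(lambda p, g: p - learning_rate * g, params, grads)
        return new_params, loss_value

    return train_step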
def main():
    print("=" * 50)
    print("JAX GPU Demo - starting")
    print("=" * 50)

    # Run all examples
    try:
        # Example 1: basic GPU operations
        result1 = basic_gpu_operations()

        # Example 2: JIT acceleration
        result2 = jit_acceleration()

        # Example 3: automatic differentiation
        result3 = autodiff_gpu()

        # Example 4: batch processing
        result4 = batch_processing()

        # Example 5: neural network
        result5 = neural_network_example()

        print("\n" + "=" * 50)
        print("All examples finished!")
        print("=" * 50)

        # Verify that the GPU is actually being used
        print("\nChecking GPU usage:")
        with jax.profiler.trace("/tmp/jax-trace", create_perfetto_link=True):
            # Run an operation so it shows up in the trace
            test_matrix = jnp.ones((1000, 1000))
            test_result = jnp.dot(test_matrix, test_matrix.T)
            print(f"Trace test finished, result shape: {test_result.shape}")

    except Exception as e:
        print(f"Error during execution: {e}")
        print("\nNote: make sure the GPU build of JAX is installed")
        print("Install command: pip install --upgrade 'jax[cuda12_pip]' -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html")


if __name__ == "__main__":
    main()
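
# --- Illustrative addition (not part of the original file): typical ways to run
# this script. JAX_PLATFORMS is a JAX environment variable for selecting a
# backend, which makes it easy to compare CPU and GPU timings:
#   python jax/test.py                    # default backend (GPU if available)
#   JAX_PLATFORMS=cpu python jax/test.py  # force CPU execution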